!19689 Fix fl namespace issue.

Merge pull request !19689 from ZPaC/fix-namespace
This commit is contained in:
i-robot 2021-07-09 08:02:48 +00:00 committed by Gitee
commit d76bb99d8a
85 changed files with 575 additions and 566 deletions

View File

@ -47,13 +47,13 @@ class FusedPullWeightKernel : public CPUKernel {
return false;
}
std::shared_ptr<ps::FBBuilder> fbb = std::make_shared<ps::FBBuilder>();
std::shared_ptr<fl::FBBuilder> fbb = std::make_shared<fl::FBBuilder>();
MS_EXCEPTION_IF_NULL(fbb);
total_iteration_++;
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
if (total_iteration_ % ps::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
ps::kTrainBeginStepNum) {
if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
fl::kTrainBeginStepNum) {
return true;
}
@ -72,10 +72,10 @@ class FusedPullWeightKernel : public CPUKernel {
const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
int retcode = schema::ResponseCode_SucNotReady;
while (retcode == schema::ResponseCode_SucNotReady) {
if (!ps::worker::FLWorker::GetInstance().SendToServer(
if (!fl::worker::FLWorker::GetInstance().SendToServer(
0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. This iteration is dropped.";
ps::worker::FLWorker::GetInstance().SetIterationRunning();
fl::worker::FLWorker::GetInstance().SetIterationRunning();
return true;
}
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
@ -116,7 +116,7 @@ class FusedPullWeightKernel : public CPUKernel {
}
}
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
ps::worker::FLWorker::GetInstance().SetIterationRunning();
fl::worker::FLWorker::GetInstance().SetIterationRunning();
return true;
}
@ -154,7 +154,7 @@ class FusedPullWeightKernel : public CPUKernel {
void InitSizeLists() { return; }
private:
bool BuildPullWeightReq(std::shared_ptr<ps::FBBuilder> fbb) {
bool BuildPullWeightReq(std::shared_ptr<fl::FBBuilder> fbb) {
MS_EXCEPTION_IF_NULL(fbb);
std::vector<flatbuffers::Offset<flatbuffers::String>> fbs_weight_names;
for (const std::string &weight_name : weight_full_names_) {

View File

@ -45,13 +45,13 @@ class FusedPushWeightKernel : public CPUKernel {
return false;
}
std::shared_ptr<ps::FBBuilder> fbb = std::make_shared<ps::FBBuilder>();
std::shared_ptr<fl::FBBuilder> fbb = std::make_shared<fl::FBBuilder>();
MS_EXCEPTION_IF_NULL(fbb);
total_iteration_++;
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
if (total_iteration_ % ps::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
ps::kTrainBeginStepNum) {
if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
fl::kTrainBeginStepNum) {
return true;
}
@ -67,17 +67,17 @@ class FusedPushWeightKernel : public CPUKernel {
}
// The server number may change after scaling in/out.
for (uint32_t i = 0; i < ps::worker::FLWorker::GetInstance().server_num(); i++) {
for (uint32_t i = 0; i < fl::worker::FLWorker::GetInstance().server_num(); i++) {
std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr;
const schema::ResponsePushWeight *push_weight_rsp = nullptr;
int retcode = schema::ResponseCode_SucNotReady;
while (retcode == schema::ResponseCode_SucNotReady) {
if (!ps::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
ps::core::TcpUserCommand::kPushWeight,
&push_weight_rsp_msg)) {
MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i
<< " failed. This iteration is dropped.";
ps::worker::FLWorker::GetInstance().SetIterationCompleted();
fl::worker::FLWorker::GetInstance().SetIterationCompleted();
return true;
}
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
@ -105,7 +105,7 @@ class FusedPushWeightKernel : public CPUKernel {
}
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
ps::worker::FLWorker::GetInstance().SetIterationCompleted();
fl::worker::FLWorker::GetInstance().SetIterationCompleted();
return true;
}
@ -143,7 +143,7 @@ class FusedPushWeightKernel : public CPUKernel {
void InitSizeLists() { return; }
private:
bool BuildPushWeightReq(std::shared_ptr<ps::FBBuilder> fbb, const std::vector<AddressPtr> &weights) {
bool BuildPushWeightReq(std::shared_ptr<fl::FBBuilder> fbb, const std::vector<AddressPtr> &weights) {
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
for (size_t i = 0; i < weight_full_names_.size(); i++) {
const std::string &weight_name = weight_full_names_[i];

View File

@ -31,8 +31,8 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
int return_num = 0;
cipher_meta_storage_.RegisterClass();
const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN);
cipher_meta_storage_.RegisterPrime(ps::server::kCtxCipherPrimer, new_prime);
if (!cipher_meta_storage_.GetPrimeFromServer(ps::server::kCtxCipherPrimer, publicparam_.prime)) {
cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime);
if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) {
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
return false;
}
@ -45,7 +45,7 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
publicparam_.t = param.t;
secrets_minnums_ = param.t;
client_num_need_ = cipher_initial_client_cnt;
featuremap_ = ps::server::ModelStore::GetInstance().model_size() / sizeof(float);
featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float);
share_clients_num_need_ = cipher_share_secrets_cnt;
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
get_model_num_need_ = cipher_get_clientlist_cnt;

View File

@ -21,7 +21,7 @@ namespace mindspore {
namespace armour {
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req,
std::shared_ptr<ps::server::FBBuilder> get_exchange_keys_resp_builder) {
std::shared_ptr<fl::server::FBBuilder> get_exchange_keys_resp_builder) {
MS_LOG(INFO) << "CipherMgr::GetKeys START";
if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
@ -32,7 +32,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
// get clientlist from memory server.
std::vector<std::string> clients;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList, &clients);
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &clients);
size_t cur_clients_num = clients.size();
std::string fl_id = get_exchange_keys_req->fl_id()->str();
@ -61,7 +61,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req,
std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder) {
std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder) {
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
// step 0: judge if the input param is legal.
if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) {
@ -75,8 +75,8 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
// step 1: get clientlist and client keys from memory server.
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
std::vector<std::string> client_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList, &client_list);
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys);
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
// step2: process new item data. and update new item data to memory server.
size_t cur_clients_num = client_list.size();
@ -131,9 +131,9 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
cur_public_key.push_back(spk);
bool retcode_key =
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(ps::server::kCtxClientsKeys, fl_id, cur_public_key);
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, fl_id, cur_public_key);
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxExChangeKeysClientList, fl_id);
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
if (retcode_key && retcode_client) {
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
@ -147,7 +147,7 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
}
}
void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder,
void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder,
const schema::ResponseCode retcode, const std::string &reason,
const std::string &next_req_time, const int iteration) {
auto rsp_reason = exchange_keys_resp_builder->CreateString(reason);
@ -162,7 +162,7 @@ void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exc
return;
}
bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode,
bool CipherKeys::BuildGetKeys(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good) {
bool flag = true;
if (is_good) {
@ -170,7 +170,7 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
MS_LOG(INFO) << "Get Keys: ";
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys);
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
if (record_public_keys.size() < cipher_init_->client_num_need_) {
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
<< "clients num: " << cipher_init_->client_num_need_;
@ -221,8 +221,8 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
}
void CipherKeys::ClearKeys() {
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxExChangeKeysClientList);
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsKeys);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxExChangeKeysClientList);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsKeys);
}
} // namespace armour

View File

@ -45,18 +45,18 @@ class CipherKeys {
// handle the client's request of get keys.
bool GetKeys(const int cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req,
std::shared_ptr<ps::server::FBBuilder> get_exchange_keys_resp_builder);
std::shared_ptr<fl::server::FBBuilder> get_exchange_keys_resp_builder);
// handle the client's request of exchange keys.
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req,
std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder);
std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder);
// build response code of get keys.
bool BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode, const int iteration,
bool BuildGetKeys(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode, const int iteration,
const std::string &next_req_time, bool is_good);
// build response code of exchange keys.
void BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder,
void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder,
const schema::ResponseCode retcode, const std::string &reason,
const std::string &next_req_time, const int iteration);
// clear the shared memory.

View File

@ -21,16 +21,16 @@ namespace armour {
void CipherMetaStorage::GetClientSharesFromServer(
const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) {
const ps::PBMetadata &clients_shares_pb_out =
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const ps::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
const fl::PBMetadata &clients_shares_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
auto iter = clients_shares_pb.client_secret_shares().begin();
for (; iter != clients_shares_pb.client_secret_shares().end(); ++iter) {
std::string fl_id = iter->first;
const ps::SharesPb &shares_pb = iter->second;
const fl::SharesPb &shares_pb = iter->second;
std::vector<clientshare_str> encrpted_shares_new;
for (int index_shares = 0; index_shares < shares_pb.clientsharestrs_size(); ++index_shares) {
const ps::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
const fl::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
clientshare_str new_clientshare;
new_clientshare.fl_id = client_share_str_pb.fl_id();
new_clientshare.index = client_share_str_pb.index();
@ -42,8 +42,8 @@ void CipherMetaStorage::GetClientSharesFromServer(
}
void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) {
const ps::PBMetadata &client_list_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const ps::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
const fl::PBMetadata &client_list_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
std::string fl_id = client_list_pb.fl_id(i);
clients_list->push_back(fl_id);
@ -52,14 +52,14 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect
void CipherMetaStorage::GetClientKeysFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) {
const ps::PBMetadata &clients_keys_pb_out =
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const ps::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
// const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
std::string fl_id = iter->first;
ps::KeysPb keys_pb = iter->second;
fl::KeysPb keys_pb = iter->second;
std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
std::vector<std::vector<unsigned char>> cur_keys;
@ -70,9 +70,9 @@ void CipherMetaStorage::GetClientKeysFromServer(
}
bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) {
const ps::PBMetadata &clients_noises_pb_out =
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const ps::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
const fl::PBMetadata &clients_noises_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
while (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "GetClientNoisesFromServer NULL.";
std::this_thread::sleep_for(std::chrono::milliseconds(50));
@ -83,8 +83,8 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
}
bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) {
const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
ps::Prime prime_pb(prime_pb_out.prime());
const fl::PBMetadata &prime_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
fl::Prime prime_pb(prime_pb_out.prime());
std::string str = *(prime_pb.mutable_prime());
MS_LOG(INFO) << "get prime from metastorage :" << str;
@ -99,20 +99,20 @@ bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
bool retcode = true;
ps::FLId fl_id_pb;
fl::FLId fl_id_pb;
fl_id_pb.set_fl_id(fl_id);
ps::PBMetadata client_pb;
fl::PBMetadata client_pb;
client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
return retcode;
}
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
MS_LOG(INFO) << "register prime: " << prime;
ps::Prime prime_id_pb;
fl::Prime prime_id_pb;
prime_id_pb.set_prime(prime);
ps::PBMetadata prime_pb;
fl::PBMetadata prime_pb;
prime_pb.mutable_prime()->MergeFrom(prime_id_pb);
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
}
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
@ -123,25 +123,25 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std
return false;
}
// update new item to memory server.
ps::KeysPb keys;
fl::KeysPb keys;
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
ps::PairClientKeys pair_client_keys_pb;
fl::PairClientKeys pair_client_keys_pb;
pair_client_keys_pb.set_fl_id(fl_id);
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
ps::PBMetadata client_and_keys_pb;
fl::PBMetadata client_and_keys_pb;
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
return retcode;
}
bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) {
// update new item to memory server.
ps::OneClientNoises noises_pb;
fl::OneClientNoises noises_pb;
*noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
ps::PBMetadata client_noises_pb;
fl::PBMetadata client_noises_pb;
client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
return ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
return fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
}
bool CipherMetaStorage::UpdateClientShareToServer(
@ -149,10 +149,10 @@ bool CipherMetaStorage::UpdateClientShareToServer(
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
bool retcode = true;
int size_shares = shares->size();
ps::SharesPb shares_pb;
fl::SharesPb shares_pb;
for (int index = 0; index < size_shares; ++index) {
// new item
ps::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
fl::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
std::string fl_id_new = (*shares)[index]->fl_id()->str();
int index_new = (*shares)[index]->index();
auto share = (*shares)[index]->share();
@ -160,32 +160,32 @@ bool CipherMetaStorage::UpdateClientShareToServer(
client_share_str_new_p->set_fl_id(fl_id_new);
client_share_str_new_p->set_index(index_new);
}
ps::PairClientShares pair_client_shares_pb;
fl::PairClientShares pair_client_shares_pb;
pair_client_shares_pb.set_fl_id(fl_id);
pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
ps::PBMetadata client_and_shares_pb;
fl::PBMetadata client_and_shares_pb;
client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
return retcode;
}
void CipherMetaStorage::RegisterClass() {
ps::PBMetadata exchange_kyes_client_list;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxExChangeKeysClientList,
fl::PBMetadata exchange_kyes_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
exchange_kyes_client_list);
ps::PBMetadata clients_keys;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsKeys, clients_keys);
ps::PBMetadata reconstruct_client_list;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxReconstructClientList,
fl::PBMetadata clients_keys;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
fl::PBMetadata reconstruct_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxReconstructClientList,
reconstruct_client_list);
ps::PBMetadata clients_reconstruct_shares;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsReconstructShares,
fl::PBMetadata clients_reconstruct_shares;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsReconstructShares,
clients_reconstruct_shares);
ps::PBMetadata share_secretes_client_list;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxShareSecretsClientList,
fl::PBMetadata share_secretes_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
share_secretes_client_list);
ps::PBMetadata clients_encrypt_shares;
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsEncryptedShares,
fl::PBMetadata clients_encrypt_shares;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
clients_encrypt_shares);
}
} // namespace armour

View File

@ -101,15 +101,15 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
bool retcode = true;
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsReconstructShares,
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&reconstruct_secret_list_ori);
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys);
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
std::vector<std::string> clients_reconstruct_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxReconstructClientList,
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
&clients_reconstruct_list);
std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList,
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list);
if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() ||
record_public_keys.size() < cipher_init_->client_num_need_ ||
@ -146,7 +146,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
client_keys.clear();
MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";
if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(ps::server::kCtxClientNoises, noise) == false)
if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise) == false)
return false;
MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
} else {
@ -159,7 +159,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
// reconstruct secrets
bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req,
std::shared_ptr<ps::server::FBBuilder> reconstruct_secret_resp_builder,
std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
const std::vector<std::string> &client_list) {
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
clock_t start_time = clock();
@ -178,10 +178,10 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
return false;
}
std::vector<std::string> clients_reconstruct_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxReconstructClientList,
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
&clients_reconstruct_list);
std::map<std::string, std::vector<clientshare_str>> clients_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsReconstructShares,
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&clients_shares_all);
size_t count_client_num = clients_shares_all.size();
@ -215,9 +215,9 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
}
auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxReconstructClientList, fl_id);
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id);
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
ps::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
if (!(retcode_client && retcode_share)) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
"reconstruct update shares or client failed.", cur_iterator, next_req_time);
@ -233,9 +233,9 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
return true;
} else {
bool retcode_result = true;
const ps::PBMetadata &clients_noises_pb_out =
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(ps::server::kCtxClientNoises);
const ps::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
const fl::PBMetadata &clients_noises_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
if (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "Success,the secret will be reconstructed.";
retcode_result = ReconstructSecretsGenNoise(client_list);
@ -279,13 +279,13 @@ bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,
void CipherReconStruct::ClearReconstructSecrets() {
MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets START";
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxReconstructClientList);
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsReconstructShares);
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientNoises);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxReconstructClientList);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsReconstructShares);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientNoises);
MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success";
}
void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<ps::server::FBBuilder> fbb,
void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb,
const schema::ResponseCode retcode, const std::string &reason,
const int iteration, const std::string &next_req_time) {
auto fbs_reason = fbb->CreateString(reason);

View File

@ -44,11 +44,11 @@ class CipherReconStruct {
// reconstruct secret mask
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req,
std::shared_ptr<ps::server::FBBuilder> reconstruct_secret_resp_builder,
std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
const std::vector<std::string> &client_list);
// build response code of reconstruct secret.
void BuildReconstructSecretsRsp(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode,
void BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason, const int iteration, const std::string &next_req_time);
// clear the shared memory.

View File

@ -21,7 +21,7 @@
namespace mindspore {
namespace armour {
bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder,
std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
const string next_req_time) {
MS_LOG(INFO) << "CipherShares::ShareSecrets START";
if (share_secrets_req == nullptr) {
@ -35,13 +35,13 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
// step 1: get client list and share secrets from memory server.
clock_t start_time = clock();
std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList,
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list);
std::vector<std::string> clients_exchange_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList,
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList,
&clients_exchange_list);
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsEncryptedShares,
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
&encrypted_shares_all);
MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size()
@ -75,9 +75,9 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
(share_secrets_req->encrypted_shares());
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
ps::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxShareSecretsClientList, fl_id_src);
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src);
bool retcode = retcode_share && retcode_client;
if (retcode) {
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
@ -95,7 +95,7 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
}
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder,
std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
const std::string &next_req_time) {
MS_LOG(INFO) << "CipherShares::GetSecrets START";
clock_t start_time = clock();
@ -108,10 +108,10 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
// step 1: get client list and client shares list from memory server.
std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList,
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list);
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsEncryptedShares,
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
&encrypted_shares_all);
int iteration = get_secrets_req->iteration();
size_t share_clients_num = clients_share_list.size();
@ -180,7 +180,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
}
void CipherShares::BuildGetSecretsRsp(
std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) {
int rsp_retcode = retcode;
int rsp_iteration = iteration;
@ -199,7 +199,7 @@ void CipherShares::BuildGetSecretsRsp(
return;
}
void CipherShares::BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder,
void CipherShares::BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
const schema::ResponseCode retcode, const string &reason,
const string &next_req_time, const int iteration) {
auto rsp_reason = share_secrets_resp_builder->CreateString(reason);
@ -211,8 +211,8 @@ void CipherShares::BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> s
}
void CipherShares::ClearShareSecrets() {
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxShareSecretsClientList);
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsEncryptedShares);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxShareSecretsClientList);
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsEncryptedShares);
}
} // namespace armour

View File

@ -43,17 +43,17 @@ class CipherShares {
// handle the client's request of share secrets.
bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder, const string next_req_time);
std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, const string next_req_time);
// handle the client's request of get secrets.
bool GetSecrets(const schema::GetShareSecrets *get_secrets_req,
std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time);
std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time);
// build response code of share secrets.
void BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder,
void BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
const schema::ResponseCode retcode, const string &reason, const string &next_req_time,
const int iteration);
// build response code of get secrets.
void BuildGetSecretsRsp(std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder,
void BuildGetSecretsRsp(std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
const schema::ResponseCode retcode, const int iteration, std::string next_req_time,
std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares);
// clear the shared memory.

View File

@ -26,13 +26,13 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
clock_t start_time = clock();
std::vector<float> noise;
cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(ps::server::kCtxClientNoises, &noise);
cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
if (noise.size() != cipher_init_->featuremap_) {
MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
return false;
}
size_t data_size = ps::server::LocalMetaStore::GetInstance().value<size_t>(ps::server::kCtxFedAvgTotalDataSize);
size_t data_size = fl::server::LocalMetaStore::GetInstance().value<size_t>(fl::server::kCtxFedAvgTotalDataSize);
int sum_size = 0;
for (auto iter = data.begin(); iter != data.end(); ++iter) {
int size_data = iter->second->size / sizeof(float);

View File

@ -17,13 +17,13 @@
#include "fl/server/collective_ops_impl.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num();
server_num_ = ps::PSContext::instance()->initial_server_num();
return;
}
@ -66,7 +66,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
// Step 1: Async send data to next rank.
size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
chunk_sizes[send_chunk_index] * sizeof(T));
// Step 2: Async receive data to next rank and wait until it's done.
size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;
@ -76,7 +76,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
@ -104,7 +104,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
for (size_t i = 0; i < rank_size - 1; i++) {
size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
chunk_sizes[send_chunk_index] * sizeof(T));
size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
@ -113,7 +113,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
@ -151,7 +151,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
for (uint32_t i = 1; i < rank_size; i++) {
std::shared_ptr<std::vector<unsigned char>> recv_str;
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str);
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
@ -167,7 +167,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
}
} else {
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
if (!server_node_->Wait(send_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
@ -180,7 +180,8 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
if (local_rank_ == 0) {
for (uint32_t i = 1; i < rank_size; i++) {
MS_LOG(DEBUG) << "Broadcast data to process " << i;
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
auto send_req_id =
server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
if (!server_node_->Wait(send_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
return false;
@ -189,7 +190,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
} else {
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str);
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
return false;
@ -247,5 +248,5 @@ template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *re
template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
#define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
#define MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
#include <memory>
#include <string>
@ -27,7 +27,7 @@
#include "fl/server/common.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// CollectiveOpsImpl is the collective communication API of the server.
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
@ -39,7 +39,7 @@ class CollectiveOpsImpl {
return instance;
}
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);
template <typename T>
bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);
@ -48,7 +48,7 @@ class CollectiveOpsImpl {
bool ReInitForScaling();
private:
CollectiveOpsImpl() = default;
CollectiveOpsImpl() : server_node_(nullptr), local_rank_(0), server_num_(0) {}
~CollectiveOpsImpl() = default;
CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;
@ -61,7 +61,7 @@ class CollectiveOpsImpl {
template <typename T>
bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<ps::core::ServerNode> server_node_;
uint32_t local_rank_;
uint32_t server_num_;
@ -69,6 +69,6 @@ class CollectiveOpsImpl {
std::mutex mtx_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
#define MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
#define MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
#include <map>
#include <string>
@ -37,7 +37,7 @@
#include "ps/core/communicator/message_handler.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// Definitions for the server framework.
enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
@ -73,7 +73,7 @@ using TimeOutCb = std::function<void(bool, const std::string &)>;
using StopTimerCb = std::function<void(void)>;
using FinishIterCb = std::function<void(bool, const std::string &)>;
using FinalizeCb = std::function<void(void)>;
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
using MessageCallback = std::function<void(const std::shared_ptr<ps::core::MessageHandler> &)>;
// Information about whether server kernel will reuse kernel node memory from the front end.
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
@ -237,6 +237,6 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size
// Definitions for Parameter Server.
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_COMMON_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/consistent_hash_ring.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
bool ConsistentHashRing::Insert(uint32_t rank) {
for (uint32_t i = 0; i < virtual_node_num_; i++) {
@ -53,5 +53,5 @@ uint32_t ConsistentHashRing::Find(const std::string &key) {
return iterator->second;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,15 +14,15 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
#define MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_
#define MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_
#include <map>
#include <string>
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in
// server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node.
@ -59,6 +59,6 @@ class ConsistentHashRing {
std::map<size_t, uint32_t> ring_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_

View File

@ -20,19 +20,19 @@
#include <vector>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node,
void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node,
uint32_t counting_server_rank) {
server_node_ = server_node;
MS_EXCEPTION_IF_NULL(server_node_);
local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num();
server_num_ = ps::PSContext::instance()->initial_server_num();
counting_server_rank_ = counting_server_rank;
return;
}
void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
communicator_ = communicator;
MS_EXCEPTION_IF_NULL(communicator_);
communicator_->RegisterMsgCallBack(
@ -94,7 +94,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
report_count_req.set_id(id);
std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount,
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
&report_cnt_rsp_msg)) {
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
return false;
@ -126,7 +126,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
return false;
}
@ -165,7 +165,7 @@ bool DistributedCountService::ReInitForScaling() {
return true;
}
void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) {
void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
@ -214,7 +214,8 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::Mes
return;
}
void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) {
void DistributedCountService::HandleCountReachThresholdRequest(
const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
@ -237,7 +238,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
return;
}
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) {
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
@ -290,7 +291,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
// Broadcast to all follower servers.
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
return false;
}
@ -308,7 +309,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
// Broadcast to all follower servers.
for (uint32_t i = 1; i < server_num_; i++) {
if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
return false;
}
@ -318,5 +319,5 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
return true;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#define MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#include <set>
#include <string>
@ -27,7 +27,7 @@
#include "ps/core/communicator/tcp_communicator.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
constexpr uint32_t kDefaultCountingServerRank = 0;
constexpr auto kModuleDistributedCountService = "DistributedCountService";
@ -54,10 +54,10 @@ class DistributedCountService {
}
// Initialize counter service with the server node because communication is needed.
void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank);
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node, uint32_t counting_server_rank);
// Register message callbacks of the counting server to handle messages sent by the other servers.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
// Register counter to the counting server for the name with its threshold count in server cluster dimension and
// first/last count event callbacks.
@ -87,15 +87,15 @@ class DistributedCountService {
DistributedCountService &operator=(const DistributedCountService &) = delete;
// Callback for the reporting count message from other servers. Only counting server will call this method.
void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Callback for the querying whether threshold count is reached message from other servers. Only counting
// server will call this method.
void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleCountReachThresholdRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Callback for the first/last event message from the counting server. Only other servers will call this
// method.
void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message);
void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
// Call the callbacks when the first/last count event is triggered.
bool TriggerCounterEvent(const std::string &name);
@ -103,8 +103,8 @@ class DistributedCountService {
bool TriggerLastCountEvent(const std::string &name);
// Members for the communication between counting server and other servers.
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
uint32_t local_rank_;
uint32_t server_num_;
@ -126,6 +126,6 @@ class DistributedCountService {
std::unordered_map<std::string, std::mutex> mutex_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_

View File

@ -20,18 +20,18 @@
#include <vector>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
server_node_ = server_node;
MS_EXCEPTION_IF_NULL(server_node);
local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num();
server_num_ = ps::PSContext::instance()->initial_server_num();
InitHashRing();
return;
}
void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
communicator_ = communicator;
MS_EXCEPTION_IF_NULL(communicator_);
communicator_->RegisterMsgCallBack(
@ -100,7 +100,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
metadata_with_name.set_name(name);
*metadata_with_name.mutable_metadata() = meta;
std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata,
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
&update_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
return false;
@ -133,7 +133,7 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
PBMetadata get_metadata_rsp;
std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata,
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetMetadata,
&get_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
return get_metadata_rsp;
@ -174,7 +174,7 @@ void DistributedMetadataStore::InitHashRing() {
return;
}
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
@ -196,7 +196,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
return;
}
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
@ -267,7 +267,7 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
auto &fl_id = meta.pair_client_shares().fl_id();
auto &client_shares = meta.pair_client_shares().client_shares();
// google::protobuf::Map< std::string, mindspore::ps::core::SharesPb >::const_iterator iter;
// google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter;
// Check whether the new item already exists.
bool add_flag = true;
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); iter++) {
@ -299,5 +299,5 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
return true;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_
#define MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_
#include <string>
#include <memory>
@ -27,7 +27,7 @@
#include "fl/server/consistent_hash_ring.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
constexpr auto kModuleDistributedMetadataStore = "DistributedMetadataStore";
// This class is used for distributed metadata storage using consistent hash. All metadata is distributedly
@ -44,10 +44,10 @@ class DistributedMetadataStore {
}
// Initialize metadata storage with the server node because communication is needed.
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);
// Register callbacks for the server to handle update/get metadata messages from other servers.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
// Register metadata for the name with the initial value. This method should be only called once for each name.
void RegisterMetadata(const std::string &name, const PBMetadata &meta);
@ -65,7 +65,13 @@ class DistributedMetadataStore {
bool ReInitForScaling();
private:
DistributedMetadataStore() = default;
DistributedMetadataStore()
: server_node_(nullptr),
communicator_(nullptr),
local_rank_(0),
server_num_(0),
router_(nullptr),
metadata_({}) {}
~DistributedMetadataStore() = default;
DistributedMetadataStore(const DistributedMetadataStore &) = delete;
DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete;
@ -74,17 +80,17 @@ class DistributedMetadataStore {
void InitHashRing();
// Callback for updating metadata request sent to the server.
void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Callback for getting metadata request sent to the server.
void HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Do updating metadata in the server where the metadata for the name is stored.
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
// Members for the communication between servers.
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
uint32_t local_rank_;
uint32_t server_num_;
@ -100,6 +106,6 @@ class DistributedMetadataStore {
std::unordered_map<std::string, std::mutex> mutex_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_

View File

@ -21,7 +21,7 @@
#include <vector>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
MS_EXCEPTION_IF_NULL(func_graph);
@ -320,5 +320,5 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
return true;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
#include <map>
#include <set>
@ -31,7 +31,7 @@
#endif
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles
// logics relevant to kernel launching.
@ -94,7 +94,7 @@ class Executor {
bool Unmask();
private:
Executor() {}
Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}) {}
~Executor() = default;
Executor(const Executor &) = delete;
Executor &operator=(const Executor &) = delete;
@ -126,6 +126,6 @@ class Executor {
#endif
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_

View File

@ -23,10 +23,10 @@
#include "fl/server/server.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
class Server;
void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
communicator_->RegisterMsgCallBack("syncIteration",
@ -42,12 +42,12 @@ void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunica
std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1));
}
void Iteration::RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node) {
void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) {
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning),
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
std::bind(&Iteration::HandleIterationRunningEvent, this));
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted),
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
std::bind(&Iteration::HandleIterationCompletedEvent, this));
}
@ -56,7 +56,7 @@ void Iteration::AddRound(const std::shared_ptr<Round> &round) {
rounds_.push_back(round);
}
void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
if (communicators.empty()) {
MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";
@ -64,7 +64,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
}
std::for_each(communicators.begin(), communicators.end(),
[&](const std::shared_ptr<core::CommunicatorBase> &communicator) {
[&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
for (auto &round : rounds_) {
if (round == nullptr) {
continue;
@ -120,7 +120,7 @@ void Iteration::SetIterationRunning() {
}
if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationRunning));
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning));
}
iteration_state_ = IterationState::kRunning;
}
@ -133,7 +133,7 @@ void Iteration::SetIterationCompleted() {
}
if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationCompleted));
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted));
}
iteration_state_ = IterationState::kCompleted;
}
@ -171,7 +171,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
sync_iter_req.set_rank(rank);
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, core::TcpUserCommand::kSyncIteration,
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration,
&sync_iter_rsp_msg)) {
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
return false;
@ -189,7 +189,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
return true;
}
void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
@ -224,14 +224,14 @@ bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const s
notify_leader_to_next_iter_req.set_iter_num(iteration_num_);
notify_leader_to_next_iter_req.set_reason(reason);
if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank,
core::TcpUserCommand::kNotifyLeaderToNextIter)) {
ps::core::TcpUserCommand::kNotifyLeaderToNextIter)) {
MS_LOG(WARNING) << "Sending notify leader server to proceed next iteration request to leader server 0 failed.";
return false;
}
return true;
}
void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}
@ -278,7 +278,7 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
std::vector<uint32_t> offline_servers = {};
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(prepare_next_iter_req, i, core::TcpUserCommand::kPrepareForNextIter)) {
if (!communicator_->SendPbRequest(prepare_next_iter_req, i, ps::core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Sending prepare for next iteration request to server " << i << " failed. Retry later.";
offline_servers.push_back(i);
continue;
@ -289,17 +289,18 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
// Should avoid endless loop if the server communicator is stopped.
while (communicator_->running() &&
!communicator_->SendPbRequest(prepare_next_iter_req, rank, core::TcpUserCommand::kPrepareForNextIter)) {
!communicator_->SendPbRequest(prepare_next_iter_req, rank, ps::core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank
<< " failed. The server has not recovered yet.";
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));
}
MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success.";
});
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
return true;
}
void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}
@ -329,7 +330,7 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
proceed_to_next_iter_req.set_last_iter_num(iteration_num_);
proceed_to_next_iter_req.set_reason(reason);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, core::TcpUserCommand::kProceedToNextIter)) {
if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, ps::core::TcpUserCommand::kProceedToNextIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
continue;
}
@ -339,7 +340,7 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
return true;
}
void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}
@ -388,7 +389,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
EndLastIterRequest end_last_iter_req;
end_last_iter_req.set_last_iter_num(last_iter_num);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, core::TcpUserCommand::kEndLastIter)) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
continue;
}
@ -398,7 +399,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
return true;
}
void Iteration::HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}
@ -429,9 +430,9 @@ void Iteration::EndLastIter() {
MS_LOG(INFO) << "End the last iteration " << iteration_num_;
iteration_num_++;
// After the job is done, reset the iteration to the initial number and reset ModelStore.
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
if (iteration_num_ > ps::PSContext::instance()->fl_iteration_num()) {
MS_LOG(INFO) << "Iteration loop " << iteration_loop_count_
<< " is completed. Iteration number: " << PSContext::instance()->fl_iteration_num();
<< " is completed. Iteration number: " << ps::PSContext::instance()->fl_iteration_num();
iteration_num_ = 1;
iteration_loop_count_++;
ModelStore::GetInstance().Reset();
@ -444,5 +445,5 @@ void Iteration::EndLastIter() {
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
#define MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
#include <memory>
#include <vector>
@ -26,7 +26,7 @@
#include "fl/server/local_meta_store.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
enum class IterationState {
// This iteration is still in process.
@ -48,16 +48,16 @@ class Iteration {
}
// Register callbacks for other servers to synchronize iteration information from leader server.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
// Register event callbacks for iteration state synchronization.
void RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node);
void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node);
// Add a round for the iteration. This method will be called multiple times for each round.
void AddRound(const std::shared_ptr<Round> &round);
// Initialize all the rounds in the iteration.
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
void InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
// This method will control servers to proceed to next iteration.
@ -104,7 +104,7 @@ class Iteration {
// Synchronize iteration form the leader server(Rank 0).
bool SyncIteration(uint32_t rank);
void HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The request for moving to next iteration is not reentrant.
bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num);
@ -112,28 +112,28 @@ class Iteration {
// The methods for moving to next iteration for all the servers.
// Step 1: follower servers notify leader server that they need to move to next iteration.
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Step 2: leader server broadcast to all follower servers to prepare for next iteration and switch to safemode.
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The server prepare for the next iteration. This method will switch the server to safemode.
void PrepareForNextIter();
// Step 3: leader server broadcast to all follower servers to move to next iteration.
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Move to next iteration. Store last iterations model and reset all the rounds.
void Next(bool is_iteration_valid, const std::string &reason);
// Step 4: leader server broadcasts to all follower servers to end last iteration and cancel the safemode.
bool BroadcastEndLastIterRequest(uint64_t iteration_num);
void HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The server end the last iteration. This method will increase the iteration number and cancel the safemode.
void EndLastIter();
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
// All the rounds in the server.
std::vector<std::shared_ptr<Round>> rounds_;
@ -155,6 +155,6 @@ class Iteration {
std::mutex pinned_mtx_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/iteration_timer.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void IterationTimer::Start(const std::chrono::milliseconds &duration) {
if (running_.load()) {
@ -52,5 +52,5 @@ bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) const
bool IterationTimer::IsRunning() const { return running_; }
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_
#define MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_
#include <chrono>
#include <atomic>
@ -24,7 +24,7 @@
#include "fl/server/common.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// IterationTimer controls the time window for the purpose of eliminating trailing time of each iteration.
class IterationTimer {
@ -59,6 +59,6 @@ class IterationTimer {
TimeOutCb timeout_callback_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#include <memory>
#include <string>
@ -26,7 +26,7 @@
#include "fl/server/kernel/params_info.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// AggregationKernel is the kernel for weight, grad or other kinds of parameters' aggregation.
@ -99,6 +99,6 @@ class AggregationKernel : public CPUKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_

View File

@ -18,7 +18,7 @@
#include <utility>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
@ -67,5 +67,5 @@ bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNod
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#include <memory>
#include <string>
@ -24,7 +24,7 @@
#include "fl/server/kernel/aggregation_kernel.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>;
@ -51,6 +51,7 @@ class AggregationKernelRegister {
AggregationKernelCreator &&creator) {
AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
}
~AggregationKernelRegister() = default;
};
// Register aggregation kernel with one template type T.
@ -66,6 +67,6 @@ class AggregationKernelRegister {
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); });
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/apply_momentum_kernel.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
REG_OPTIMIZER_KERNEL(ApplyMomentum,
@ -30,5 +30,5 @@ REG_OPTIMIZER_KERNEL(ApplyMomentum,
ApplyMomentumKernel, float)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#include <vector>
#include <memory>
@ -25,7 +25,7 @@
#include "fl/server/kernel/optimizer_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using mindspore::kernel::ApplyMomentumCPUKernel;
@ -57,6 +57,6 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/dense_grad_accum_kernel.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
REG_AGGREGATION_KERNEL(
@ -26,5 +26,5 @@ REG_AGGREGATION_KERNEL(
DenseGradAccumKernel, float)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#include <memory>
#include <string>
@ -26,7 +26,7 @@
#include "fl/server/kernel/aggregation_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
template <typename T>
@ -90,6 +90,6 @@ class DenseGradAccumKernel : public AggregationKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/fed_avg_kernel.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
REG_AGGREGATION_KERNEL_TWO(FedAvg,
@ -29,5 +29,5 @@ REG_AGGREGATION_KERNEL_TWO(FedAvg,
FedAvgKernel, float, size_t)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_
#include <memory>
#include <string>
@ -31,7 +31,7 @@
#include "fl/server/kernel/aggregation_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// The implementation for the federated average. We do weighted average for the weights. The uploaded weights from
@ -42,7 +42,13 @@ namespace kernel {
template <typename T, typename S>
class FedAvgKernel : public AggregationKernel {
public:
FedAvgKernel() : participated_(false) {}
FedAvgKernel()
: cnode_weight_idx_(0),
weight_addr_(nullptr),
data_size_addr_(nullptr),
new_weight_addr_(nullptr),
new_data_size_addr_(nullptr),
participated_(false) {}
~FedAvgKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override {
@ -68,13 +74,13 @@ class FedAvgKernel : public AggregationKernel {
AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
MS_EXCEPTION_IF_NULL(weight_node);
name_ = cnode_name + "." + weight_node->fullname_with_scope();
first_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) {
first_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
std::unique_lock<std::mutex> lock(weight_mutex_);
if (!participated_) {
ClearWeightAndDataSize();
}
};
last_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) {
last_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
size_t weight_size = weight_addr_->size;
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
@ -193,7 +199,7 @@ class FedAvgKernel : public AggregationKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
#include <memory>
#include <string>
@ -26,7 +26,7 @@
#include "fl/server/kernel/params_info.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// KernelFactory is used to select and build kernels in server. It's the base class of OptimizerKernelFactory
@ -87,6 +87,6 @@ class KernelFactory {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#include <memory>
#include <string>
@ -28,7 +28,7 @@
#include "fl/server/kernel/params_info.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using mindspore::kernel::IsSameShape;
@ -92,6 +92,6 @@ class OptimizerKernel : public CPUKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_

View File

@ -18,7 +18,7 @@
#include <utility>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
@ -66,5 +66,5 @@ bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodeP
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#include <memory>
#include <string>
@ -24,7 +24,7 @@
#include "fl/server/kernel/optimizer_kernel.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>;
@ -50,6 +50,7 @@ class OptimizerKernelRegister {
OptimizerKernelRegister(const std::string &name, const ParamsInfo &params_info, OptimizerKernelCreator &&creator) {
OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
}
~OptimizerKernelRegister() = default;
};
// Register optimizer kernel with one template type T.
@ -59,6 +60,6 @@ class OptimizerKernelRegister {
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_

View File

@ -18,7 +18,7 @@
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
ParamsInfo &ParamsInfo::AddInputNameType(const std::string &name, TypeId type) {
@ -64,5 +64,5 @@ const std::vector<std::string> &ParamsInfo::workspace_names() const { return wor
const std::vector<std::string> &ParamsInfo::outputs_names() const { return outputs_names_; }
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_
#include <utility>
#include <string>
@ -23,7 +23,7 @@
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// ParamsInfo is used for server computation kernel's register, e.g, ApplyMomentumKernel, FedAvgKernel, etc.
@ -65,6 +65,6 @@ class ParamsInfo {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_

View File

@ -22,7 +22,7 @@
#include "schema/cipher_generated.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void ClientListKernel::InitKernel(size_t) {
@ -150,7 +150,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
return true;
} // namespace ps
} // namespace fl
bool ClientListKernel::Reset() {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
@ -196,5 +196,5 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli
REG_ROUND_KERNEL(getClientList, ClientListKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#include <string>
#include <vector>
#include <memory>
@ -26,7 +26,7 @@
#include "fl/server/executor.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class ClientListKernel : public RoundKernel {
@ -50,6 +50,6 @@ class ClientListKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H

View File

@ -20,7 +20,7 @@
#include <memory>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void ExchangeKeysKernel::InitKernel(size_t) {
@ -100,5 +100,5 @@ bool ExchangeKeysKernel::Reset() {
REG_ROUND_KERNEL(exchangeKeys, ExchangeKeysKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#include <vector>
#include "fl/server/common.h"
@ -25,7 +25,7 @@
#include "fl/armour/cipher/cipher_keys.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class ExchangeKeysKernel : public RoundKernel {
@ -44,7 +44,7 @@ class ExchangeKeysKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H

View File

@ -19,7 +19,7 @@
#include <memory>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void GetKeysKernel::InitKernel(size_t) {
@ -99,5 +99,5 @@ bool GetKeysKernel::Reset() {
REG_ROUND_KERNEL(getKeys, GetKeysKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
#include <vector>
#include "fl/server/common.h"
@ -25,7 +25,7 @@
#include "fl/armour/cipher/cipher_keys.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class GetKeysKernel : public RoundKernel {
@ -44,7 +44,7 @@ class GetKeysKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H

View File

@ -23,7 +23,7 @@
#include "fl/server/model_store.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void GetModelKernel::InitKernel(size_t) {
@ -133,5 +133,5 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con
REG_ROUND_KERNEL(getModel, GetModelKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#include <map>
#include <memory>
@ -27,13 +27,13 @@
#include "fl/server/kernel/round/round_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
constexpr uint32_t kPrintGetModelForEveryRetryTime = 50;
class GetModelKernel : public RoundKernel {
public:
GetModelKernel() = default;
GetModelKernel() : executor_(nullptr), iteration_time_window_(0), retry_count_(0) {}
~GetModelKernel() override = default;
void InitKernel(size_t) override;
@ -58,6 +58,6 @@ class GetModelKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

View File

@ -21,7 +21,7 @@
#include "fl/armour/cipher/cipher_shares.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void GetSecretsKernel::InitKernel(size_t) {
@ -102,5 +102,5 @@ bool GetSecretsKernel::Reset() {
REG_ROUND_KERNEL(getSecrets, GetSecretsKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#include <vector>
#include "fl/server/common.h"
@ -25,7 +25,7 @@
#include "fl/server/executor.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class GetSecretsKernel : public RoundKernel {
@ -44,7 +44,7 @@ class GetSecretsKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H

View File

@ -22,7 +22,7 @@
#include "fl/server/model_store.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void PullWeightKernel::InitKernel(size_t) {
@ -137,5 +137,5 @@ void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const
REG_ROUND_KERNEL(pullWeight, PullWeightKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#include <map>
#include <memory>
@ -27,7 +27,7 @@
#include "fl/server/executor.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
constexpr uint32_t kPrintPullWeightForEveryRetryTime = 500;
@ -53,6 +53,6 @@ class PullWeightKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/round/push_weight_kernel.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void PushWeightKernel::InitKernel(size_t) {
@ -60,8 +60,8 @@ bool PushWeightKernel::Reset() {
return true;
}
void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) {
if (PSContext::instance()->resetter_round() == ResetterRound::kPushWeight) {
void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushWeight) {
FinishIteration();
}
return;
@ -136,5 +136,5 @@ void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const
REG_ROUND_KERNEL(pushWeight, PushWeightKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#include <map>
#include <memory>
@ -27,7 +27,7 @@
#include "fl/server/executor.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class PushWeightKernel : public RoundKernel {
@ -39,7 +39,7 @@ class PushWeightKernel : public RoundKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private:
bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
@ -52,6 +52,6 @@ class PushWeightKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_

View File

@ -20,7 +20,7 @@
#include <memory>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
@ -34,17 +34,17 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
MS_LOG(INFO) << "start FinishIteration";
FinishIteration();
MS_LOG(INFO) << "end FinishIteration";
return;
};
auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) { return; };
auto first_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) { return; };
name_unmask_ = "UnMaskKernel";
MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : "
<< LocalMetaStore::GetInstance().curr_iter_num();
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, PSContext::instance()->initial_server_num(),
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, ps::PSContext::instance()->initial_server_num(),
{first_cnt_handler, last_cnt_handler});
}
@ -134,9 +134,9 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
return true;
}
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
if (PSContext::instance()->encrypt_type() == kPWEncryptType) {
if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) {
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
@ -164,5 +164,5 @@ bool ReconstructSecretsKernel::Reset() {
REG_ROUND_KERNEL(reconstructSecrets, ReconstructSecretsKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#include <vector>
#include <memory>
@ -27,7 +27,7 @@
#include "fl/server/executor.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class ReconstructSecretsKernel : public RoundKernel {
@ -39,7 +39,7 @@ class ReconstructSecretsKernel : public RoundKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private:
std::string name_unmask_;
@ -49,6 +49,6 @@ class ReconstructSecretsKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_

View File

@ -24,7 +24,7 @@
#include <vector>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) {
@ -61,9 +61,9 @@ RoundKernel::~RoundKernel() {
}
}
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) { return; }
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
void RoundKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) { return; }
void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
void RoundKernel::StopTimer() const {
if (stop_timer_cb_) {
@ -129,5 +129,5 @@ void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const v
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#include <map>
#include <memory>
@ -35,7 +35,7 @@
#include "fl/server/distributed_metadata_store.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
@ -67,8 +67,8 @@ class RoundKernel : virtual public CPUKernel {
// The counter event handlers for DistributedCountService.
// The callbacks when first message and last message for this round kernel is received.
// These methods is called by class DistributedCountService and triggered by counting server.
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
virtual void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
virtual void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
// Called when this round is finished. This round timer's Stop method will be called.
void StopTimer() const;
@ -123,6 +123,6 @@ class RoundKernel : virtual public CPUKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/kernel/round/round_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
RoundKernelFactory &RoundKernelFactory::GetInstance() {
@ -40,5 +40,5 @@ std::shared_ptr<RoundKernel> RoundKernelFactory::Create(const std::string &name)
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#include <memory>
#include <string>
@ -25,7 +25,7 @@
#include "fl/server/kernel/round/round_kernel.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>;
@ -50,6 +50,7 @@ class RoundKernelRegister {
RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) {
RoundKernelFactory::GetInstance().Register(name, std::move(creator));
}
~RoundKernelRegister() = default;
};
#define REG_ROUND_KERNEL(NAME, CLASS) \
@ -57,6 +58,6 @@ class RoundKernelRegister {
static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); });
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_

View File

@ -19,7 +19,7 @@
#include <memory>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void ShareSecretsKernel::InitKernel(size_t) {
@ -101,5 +101,5 @@ bool ShareSecretsKernel::Reset() {
REG_ROUND_KERNEL(shareSecrets, ShareSecretsKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#include <vector>
#include "fl/server/common.h"
@ -25,7 +25,7 @@
#include "fl/armour/cipher/cipher_shares.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class ShareSecretsKernel : public RoundKernel {
@ -44,7 +44,7 @@ class ShareSecretsKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H

View File

@ -26,7 +26,7 @@
#endif
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void StartFLJobKernel::InitKernel(size_t) {
@ -113,8 +113,8 @@ bool StartFLJobKernel::Reset() {
return true;
}
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
// The first startFLJob request means a new iteration starts running.
Iteration::GetInstance().SetIterationRunning();
@ -194,8 +194,8 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
std::map<std::string, AddressPtr> feature_maps) {
auto fbs_reason = fbb->CreateString(reason);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode());
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
auto fbs_server_mode = fbb->CreateString(ps::PSContext::instance()->server_mode());
auto fbs_fl_name = fbb->CreateString(ps::PSContext::instance()->fl_name());
#ifdef ENABLE_ARMOUR
auto *param = armour::CipherInit::GetInstance().GetPublicParams();
@ -206,7 +206,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
float dp_eps = param->dp_eps;
float dp_delta = param->dp_delta;
float dp_norm_clip = param->dp_norm_clip;
auto encrypt_type = fbb->CreateString(PSContext::instance()->encrypt_type());
auto encrypt_type = fbb->CreateString(ps::PSContext::instance()->encrypt_type());
auto cipher_public_params =
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);
@ -215,10 +215,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
fl_plan_builder.add_fl_name(fbs_fl_name);
fl_plan_builder.add_server_mode(fbs_server_mode);
fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate());
fl_plan_builder.add_iterations(ps::PSContext::instance()->fl_iteration_num());
fl_plan_builder.add_epochs(ps::PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(ps::PSContext::instance()->client_batch_size());
fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate());
#ifdef ENABLE_ARMOUR
fl_plan_builder.add_cipher(cipher_public_params);
#endif
@ -250,5 +250,5 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
REG_ROUND_KERNEL(startFLJob, StartFLJobKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#include <map>
#include <memory>
@ -27,7 +27,7 @@
#include "fl/server/kernel/round/round_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class StartFLJobKernel : public RoundKernel {
@ -40,7 +40,7 @@ class StartFLJobKernel : public RoundKernel {
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private:
// Returns whether the startFLJob count of this iteration has reached the threshold.
@ -74,6 +74,6 @@ class StartFLJobKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_

View File

@ -21,7 +21,7 @@
#include "fl/server/kernel/round/update_model_kernel.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void UpdateModelKernel::InitKernel(size_t threshold_count) {
@ -87,8 +87,8 @@ bool UpdateModelKernel::Reset() {
return true;
}
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) {
if (PSContext::instance()->resetter_round() == ResetterRound::kUpdateModel) {
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) {
while (!executor_->IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
@ -96,7 +96,7 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
<< total_data_size;
if (PSContext::instance()->encrypt_type() != kPWEncryptType) {
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
FinishIteration();
}
}
@ -226,5 +226,5 @@ void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fb
REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#include <map>
#include <memory>
@ -27,7 +27,7 @@
#include "fl/server/executor.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
@ -44,7 +44,7 @@ class UpdateModelKernel : public RoundKernel {
bool Reset() override;
// In some cases, the last updateModel message means this server iteration is finished.
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
private:
bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
@ -62,6 +62,6 @@ class UpdateModelKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

View File

@ -17,7 +17,7 @@
#include "fl/server/local_meta_store.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void LocalMetaStore::remove_value(const std::string &name) {
std::unique_lock<std::mutex> lock(mtx_);
@ -41,5 +41,5 @@ const size_t LocalMetaStore::curr_iter_num() {
return curr_iter_num_;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_
#define MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_
#include <any>
#include <mutex>
@ -24,7 +24,7 @@
#include "fl/server/common.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// LocalMetaStore class is used for metadata storage of this server process.
// For example, the current iteration number, time windows for round kernels, etc.
@ -71,7 +71,7 @@ class LocalMetaStore {
const size_t curr_iter_num();
private:
LocalMetaStore() = default;
LocalMetaStore() : key_to_meta_({}), curr_iter_num_(0) {}
~LocalMetaStore() = default;
LocalMetaStore(const LocalMetaStore &) = delete;
LocalMetaStore &operator=(const LocalMetaStore &) = delete;
@ -83,6 +83,6 @@ class LocalMetaStore {
size_t curr_iter_num_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_

View File

@ -18,7 +18,7 @@
#include <utility>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) {
addresses_.try_emplace(name, address);
@ -32,5 +32,5 @@ void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64
void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); }
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
#define MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_
#define MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_
#include <map>
#include <string>
@ -26,7 +26,7 @@
#include "fl/server/common.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// Memory allocated in server is normally trainable parameters, hyperparameters, gradients, etc.
// MemoryRegister registers the Memory with key-value format where key refers to address's name("grad", "weights",
@ -88,6 +88,6 @@ class MemoryRegister {
}
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_

View File

@ -21,7 +21,7 @@
#include "fl/server/executor.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void ModelStore::Initialize(uint32_t max_count) {
if (!Executor::GetInstance().initialized()) {
@ -155,5 +155,5 @@ size_t ModelStore::ComputeModelSize() {
return model_size;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_
#define MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_
#include <map>
#include <memory>
@ -25,7 +25,7 @@
#include "fl/server/executor.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// The initial iteration number is 0 in server.
constexpr size_t kInitIterationNum = 0;
@ -84,6 +84,6 @@ class ModelStore {
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_

View File

@ -23,7 +23,7 @@
#include <algorithm>
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
MS_EXCEPTION_IF_NULL(cnode);
@ -199,8 +199,8 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
}
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
if (PSContext::instance()->server_mode() == kServerModeFL ||
PSContext::instance()->server_mode() == kServerModeHybrid) {
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
return true;
}
@ -321,13 +321,13 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
std::vector<std::string> aggregation_algorithm = {};
if (PSContext::instance()->server_mode() == kServerModeFL ||
PSContext::instance()->server_mode() == kServerModeHybrid) {
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
aggregation_algorithm.push_back("FedAvg");
} else if (PSContext::instance()->server_mode() == kServerModePS) {
} else if (ps::PSContext::instance()->server_mode() == ps::kServerModePS) {
aggregation_algorithm.push_back("DenseGradAccum");
} else {
MS_LOG(ERROR) << "Server doesn't support mode " << PSContext::instance()->server_mode();
MS_LOG(ERROR) << "Server doesn't support mode " << ps::PSContext::instance()->server_mode();
}
MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
@ -344,5 +344,5 @@ template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::Aggregat
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
std::shared_ptr<MemoryRegister> memory_register);
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
#define MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
#define MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
#include <map>
#include <memory>
@ -28,7 +28,7 @@
#include "fl/server/kernel/optimizer_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server
// kernels.
@ -137,6 +137,6 @@ class ParameterAggregator {
std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_

View File

@ -21,7 +21,7 @@
#include "fl/server/iteration.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
class Server;
class Iteration;
@ -34,14 +34,14 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo
threshold_count_(threshold_count),
server_num_as_threshold_(server_num_as_threshold) {}
void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
// Register callback for round kernel.
communicator_->RegisterMsgCallBack(
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });
name_, [&](std::shared_ptr<ps::core::MessageHandler> message) { LaunchRoundKernel(message); });
// Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
@ -106,7 +106,7 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
return;
}
void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) {
void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
@ -152,7 +152,7 @@ bool Round::check_timeout() const { return check_timeout_; }
size_t Round::time_window() const { return time_window_; }
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
// The timer starts only after the first count event is triggered by DistributedCountService.
if (check_timeout_) {
@ -164,7 +164,7 @@ void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &messa
return;
}
void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
// Same as the first count event, the timer must be stopped by DistributedCountService.
if (check_timeout_) {
@ -176,5 +176,5 @@ void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &messag
return;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
#define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
#define MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
#include <memory>
#include <string>
@ -26,7 +26,7 @@
#include "fl/server/kernel/round/round_kernel.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// Round helps server to handle network round messages and launch round kernels. One iteration in server consists of
// multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting
@ -37,7 +37,7 @@ class Round {
bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
~Round() = default;
void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
FinishIterCb finish_iteration_cb);
// Reinitialize count service and round kernel of this round after scaling operations are done.
@ -48,7 +48,7 @@ class Round {
// This method is the callback which will be set to the communicator and called after the corresponding round message
// is sent to the server.
void LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message);
void LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message);
// Round needs to be reset after each iteration is finished or its timer expires.
void Reset();
@ -60,8 +60,8 @@ class Round {
private:
// The callbacks which will be set to DistributedCounterService.
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
std::string name_;
@ -83,7 +83,7 @@ class Round {
// Whether this round uses the server number as its threshold count.
bool server_num_as_threshold_;
std::shared_ptr<core::CommunicatorBase> communicator_;
std::shared_ptr<ps::core::CommunicatorBase> communicator_;
// The round kernel for this Round.
std::shared_ptr<kernel::RoundKernel> kernel_;
@ -97,6 +97,6 @@ class Round {
FinalizeCb finalize_cb_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_ROUND_H_

View File

@ -30,17 +30,8 @@
#include "fl/server/kernel/round/round_kernel_factory.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
static std::vector<std::shared_ptr<core::CommunicatorBase>> global_worker_server_comms = {};
// This function is for the exit of server process when an interrupt signal is captured.
void SignalHandler(int signal) {
MS_LOG(INFO) << "Interrupt signal captured: " << signal;
std::for_each(global_worker_server_comms.begin(), global_worker_server_comms.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
return;
}
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
MS_EXCEPTION_IF_NULL(func_graph);
@ -76,7 +67,6 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
// InitCipher---->InitExecutor
void Server::Run() {
signal(SIGINT, SignalHandler);
std::unique_lock<std::mutex> lock(scaling_mtx_);
InitServerContext();
InitCluster();
@ -84,8 +74,8 @@ void Server::Run() {
RegisterCommCallbacks();
StartCommunicator();
InitExecutor();
std::string encrypt_type = PSContext::instance()->encrypt_type();
if (encrypt_type != kNotEncryptType) {
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type != ps::kNotEncryptType) {
InitCipher();
MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
}
@ -96,7 +86,7 @@ void Server::Run() {
// Wait communicators to stop so the main thread is blocked.
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Join(); });
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Join(); });
communicator_with_server_->Join();
MsException::Instance().CheckException();
return;
@ -115,18 +105,18 @@ void Server::CancelSafeMode() {
bool Server::IsSafeMode() { return safemode_.load(); }
void Server::InitServerContext() {
PSContext::instance()->GenerateResetterRound();
scheduler_ip_ = PSContext::instance()->scheduler_host();
scheduler_port_ = PSContext::instance()->scheduler_port();
worker_num_ = PSContext::instance()->initial_worker_num();
server_num_ = PSContext::instance()->initial_server_num();
ps::PSContext::instance()->GenerateResetterRound();
scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
scheduler_port_ = ps::PSContext::instance()->scheduler_port();
worker_num_ = ps::PSContext::instance()->initial_worker_num();
server_num_ = ps::PSContext::instance()->initial_server_num();
return;
}
void Server::InitCluster() {
server_node_ = std::make_shared<core::ServerNode>();
server_node_ = std::make_shared<ps::core::ServerNode>();
MS_EXCEPTION_IF_NULL(server_node_);
task_executor_ = std::make_shared<core::TaskExecutor>(32);
task_executor_ = std::make_shared<ps::core::TaskExecutor>(32);
MS_EXCEPTION_IF_NULL(task_executor_);
if (!InitCommunicatorWithServer()) {
MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
@ -136,7 +126,6 @@ void Server::InitCluster() {
MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
return;
}
global_worker_server_comms = communicators_with_worker_;
return;
}
@ -187,8 +176,8 @@ void Server::InitIteration() {
}
#ifdef ENABLE_ARMOUR
std::string encrypt_type = PSContext::instance()->encrypt_type();
if (encrypt_type == kPWEncryptType) {
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) {
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
@ -245,10 +234,10 @@ void Server::InitCipher() {
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
int cipher_g = 1;
unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
float dp_eps = PSContext::instance()->dp_eps();
float dp_delta = PSContext::instance()->dp_delta();
float dp_norm_clip = PSContext::instance()->dp_norm_clip();
std::string encrypt_type = PSContext::instance()->encrypt_type();
float dp_eps = ps::PSContext::instance()->dp_eps();
float dp_delta = ps::PSContext::instance()->dp_delta();
float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
mpz_t prim;
mpz_init(prim);
@ -276,7 +265,7 @@ void Server::RegisterCommCallbacks() {
// The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
// rounds' callbacks.
auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
MS_EXCEPTION_IF_NULL(tcp_comm);
// Set message callbacks for server-to-server communication.
@ -304,23 +293,23 @@ void Server::RegisterCommCallbacks() {
std::bind(&Server::ProcessAfterScalingIn, this));
}
void Server::RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
safemode_ = true;
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
});
communicator->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [&]() {
communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
MS_LOG(ERROR)
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
"network building phase.";
safemode_ = true;
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
});
}
@ -377,7 +366,7 @@ void Server::StartCommunicator() {
MS_LOG(INFO) << "Start communicator with worker.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Start(); });
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Start(); });
}
void Server::ProcessBeforeScalingOut() {
@ -424,7 +413,7 @@ void Server::ProcessAfterScalingIn() {
if (server_node_->rank_id() == UINT32_MAX) {
MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
communicator_with_server_->Stop();
return;
}
@ -449,5 +438,5 @@ void Server::ProcessAfterScalingIn() {
safemode_ = false;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
#define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
#define MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
#include <memory>
#include <string>
@ -31,7 +31,7 @@
#endif
namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
class Server {
@ -90,7 +90,7 @@ class Server {
void RegisterCommCallbacks();
// Register cluster exception callbacks. This method is called in RegisterCommCallbacks.
void RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
void RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
// Initialize executor according to the server mode.
void InitExecutor();
@ -113,11 +113,11 @@ class Server {
void ProcessAfterScalingIn();
// The server node is initialized in Server.
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<ps::core::ServerNode> server_node_;
// The task executor of the communicators. This helps server to handle network message concurrently. The tasks
// submitted to this task executor is asynchronous.
std::shared_ptr<core::TaskExecutor> task_executor_;
std::shared_ptr<ps::core::TaskExecutor> task_executor_;
// Which protocol should communicators use.
bool use_tcp_;
@ -136,12 +136,12 @@ class Server {
// Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective
// operations, etc.
std::shared_ptr<core::CommunicatorBase> communicator_with_server_;
std::shared_ptr<ps::core::CommunicatorBase> communicator_with_server_;
// The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP.
// In some cases, both types should be supported in one distributed training job. So here we may have multiple
// communicators.
std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;
std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
// Mutex for scaling operations. We must wait server's initialization done before handle scaling events.
std::mutex scaling_mtx_;
@ -176,6 +176,6 @@ class Server {
float percent_for_get_model_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_SERVER_H_

View File

@ -22,27 +22,27 @@
#include "utils/ms_exception.h"
namespace mindspore {
namespace ps {
namespace fl {
namespace worker {
void FLWorker::Run() {
worker_num_ = PSContext::instance()->worker_num();
server_num_ = PSContext::instance()->server_num();
scheduler_ip_ = PSContext::instance()->scheduler_ip();
scheduler_port_ = PSContext::instance()->scheduler_port();
worker_step_num_per_iteration_ = PSContext::instance()->worker_step_num_per_iteration();
PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_;
PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
PSContext::instance()->cluster_config().initial_worker_num = worker_num_;
PSContext::instance()->cluster_config().initial_server_num = server_num_;
worker_num_ = ps::PSContext::instance()->worker_num();
server_num_ = ps::PSContext::instance()->server_num();
scheduler_ip_ = ps::PSContext::instance()->scheduler_ip();
scheduler_port_ = ps::PSContext::instance()->scheduler_port();
worker_step_num_per_iteration_ = ps::PSContext::instance()->worker_step_num_per_iteration();
ps::PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_;
ps::PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
ps::PSContext::instance()->cluster_config().initial_worker_num = worker_num_;
ps::PSContext::instance()->cluster_config().initial_server_num = server_num_;
MS_LOG(INFO) << "Initialize cluster config for worker. Worker number:" << worker_num_
<< ", Server number:" << server_num_ << ", Scheduler ip:" << scheduler_ip_
<< ", Scheduler port:" << scheduler_port_
<< ", Worker training step per iteration:" << worker_step_num_per_iteration_;
worker_node_ = std::make_shared<core::WorkerNode>();
worker_node_ = std::make_shared<ps::core::WorkerNode>();
MS_EXCEPTION_IF_NULL(worker_node_);
worker_node_->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
worker_node_->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
Finalize();
try {
MS_LOG(EXCEPTION)
@ -51,7 +51,7 @@ void FLWorker::Run() {
MsException::Instance().SetException();
}
});
worker_node_->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() {
worker_node_->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
Finalize();
try {
MS_LOG(EXCEPTION)
@ -74,7 +74,7 @@ void FLWorker::Finalize() {
worker_node_->Stop();
}
bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, core::TcpUserCommand command,
bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output) {
// If the worker is in safemode, do not communicate with server.
while (safemode_.load()) {
@ -97,7 +97,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
if (output != nullptr) {
while (true) {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output)) {
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
output)) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false;
}
@ -106,7 +107,7 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
return false;
}
if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode) {
if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == ps::kClusterSafeMode) {
MS_LOG(INFO) << "The server " << server_rank << " is in safemode.";
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode));
} else {
@ -114,7 +115,7 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
}
}
} else {
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
return false;
}
@ -155,9 +156,9 @@ void FLWorker::InitializeFollowerScaler() {
std::bind(&FLWorker::ProcessAfterScalingOut, this));
worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
std::bind(&FLWorker::ProcessAfterScalingIn, this));
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning),
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
std::bind(&FLWorker::HandleIterationRunningEvent, this));
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted),
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
std::bind(&FLWorker::HandleIterationCompletedEvent, this));
}
@ -222,5 +223,5 @@ void FLWorker::ProcessAfterScalingIn() {
safemode_ = false;
}
} // namespace worker
} // namespace ps
} // namespace fl
} // namespace mindspore

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_
#define MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_
#ifndef MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
#define MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
#include <memory>
#include <string>
@ -28,7 +28,7 @@
#include "ps/core/communicator/tcp_communicator.h"
namespace mindspore {
namespace ps {
namespace fl {
using FBBuilder = flatbuffers::FlatBufferBuilder;
// The step number for worker to judge whether to communicate with server.
@ -59,7 +59,7 @@ class FLWorker {
}
void Run();
void Finalize();
bool SendToServer(uint32_t server_rank, const void *data, size_t size, core::TcpUserCommand command,
bool SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
std::shared_ptr<std::vector<unsigned char>> *output = nullptr);
uint32_t server_num() const;
@ -104,7 +104,7 @@ class FLWorker {
uint32_t worker_num_;
std::string scheduler_ip_;
uint16_t scheduler_port_;
std::shared_ptr<core::WorkerNode> worker_node_;
std::shared_ptr<ps::core::WorkerNode> worker_node_;
// The worker standalone training step number before communicating with server. This used in hybrid training mode.
uint64_t worker_step_num_per_iteration_;
@ -121,6 +121,6 @@ class FLWorker {
std::atomic_bool safemode_;
};
} // namespace worker
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_
#endif // MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_

View File

@ -639,7 +639,7 @@ bool StartPSWorkerAction(const ResourcePtr &res) {
return true;
}
bool StartFLWorkerAction(const ResourcePtr &) {
ps::worker::FLWorker::GetInstance().Run();
fl::worker::FLWorker::GetInstance().Run();
return true;
}
@ -665,7 +665,7 @@ bool StartServerAction(const ResourcePtr &res) {
uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
std::vector<ps::server::RoundConfig> rounds_config = {
std::vector<fl::server::RoundConfig> rounds_config = {
{"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
{"updateModel", true, update_model_time_window, true, update_model_threshold},
{"getModel"},
@ -676,22 +676,22 @@ bool StartServerAction(const ResourcePtr &res) {
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
ps::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
executor_threshold = update_model_threshold;
ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
fl::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
executor_threshold);
} else if (server_mode_ == ps::kServerModePS) {
executor_threshold = worker_num;
ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
fl::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
executor_threshold);
} else {
MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
return false;
}
ps::server::Server::GetInstance().Run();
fl::server::Server::GetInstance().Run();
return true;
}

View File

@ -1293,7 +1293,7 @@ void ClearResAtexit() {
MS_LOG(INFO) << "Start finalizing worker.";
const std::string &server_mode = ps::PSContext::instance()->server_mode();
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid)) {
ps::worker::FLWorker::GetInstance().Finalize();
fl::worker::FLWorker::GetInstance().Finalize();
} else {
ps::Worker::GetInstance().Finalize();
}

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
syntax = "proto3";
package mindspore.ps;
package mindspore.fl;
message CollectiveData {
bytes data = 1;

View File

@ -286,8 +286,9 @@ void PSContext::GenerateResetterRound() {
return;
}
binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3);
binary_server_context = ((unsigned int)is_parameter_server_mode << 0) |
((unsigned int)is_federated_learning_mode << 1) |
((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)secure_aggregation_ << 3);
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
resetter_round_ = ResetterRound::kNoNeedToReset;
} else {