!19760 Fix fl namespace issue.
Merge pull request !19760 from ZPaC/1.3-change-dir
commit fd18b146bb
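Note for reviewers: the change is mechanical but wide. Federated-learning code that previously lived under the mindspore::ps namespace moves to mindspore::fl, while shared infrastructure (ps::core, ps::PSContext) stays where it is, so references to it become explicitly qualified. A minimal sketch of the pattern, using made-up types rather than the MindSpore sources:

    #include <iostream>

    namespace mindspore {
    namespace ps {
    namespace core {  // shared infrastructure: stays under ps::core after the move
    struct ServerNode {
      int rank_id() const { return 0; }
    };
    }  // namespace core
    }  // namespace ps

    namespace fl {  // federated-learning code: moved here from mindspore::ps
    namespace server {
    void Run(const ps::core::ServerNode &node) {  // ps::core now referenced explicitly
      std::cout << "fl::server running on rank " << node.rank_id() << "\n";
    }
    }  // namespace server
    }  // namespace fl
    }  // namespace mindspore

    int main() {
      mindspore::ps::core::ServerNode node;
      mindspore::fl::server::Run(node);
      return 0;
    }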
@@ -47,13 +47,13 @@ class FusedPullWeightKernel : public CPUKernel {
       return false;
     }

-    std::shared_ptr<ps::FBBuilder> fbb = std::make_shared<ps::FBBuilder>();
+    std::shared_ptr<fl::FBBuilder> fbb = std::make_shared<fl::FBBuilder>();
     MS_EXCEPTION_IF_NULL(fbb);

     total_iteration_++;
     // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
-    if (total_iteration_ % ps::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
-        ps::kTrainBeginStepNum) {
+    if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
+        fl::kTrainBeginStepNum) {
       return true;
     }
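The gating logic in this hunk is easy to misread: the worker skips communication (the early return true) on every step where total_iteration_ % worker_step_num_per_iteration() != kTrainBeginStepNum, i.e. it only syncs once per local-training window. A standalone sketch of the same arithmetic, with illustrative constants in place of the FLWorker accessors:

    #include <cstdint>
    #include <iostream>

    // Hypothetical stand-ins for fl::worker::FLWorker::GetInstance()
    // .worker_step_num_per_iteration() and fl::kTrainBeginStepNum; the real
    // values come from the FL worker configuration.
    constexpr uint64_t kStepsPerIteration = 65;  // assumption: 64 local steps + 1 sync step
    constexpr uint64_t kTrainBeginStepNum = 1;

    bool ShouldSyncWithServer(uint64_t total_iteration) {
      // Communicate only when the counter wraps to the "begin" step; all other
      // steps are standalone local training, exactly like the kernel's early return.
      return total_iteration % kStepsPerIteration == kTrainBeginStepNum;
    }

    int main() {
      for (uint64_t step = 1; step <= 131; ++step) {
        if (ShouldSyncWithServer(step)) {
          std::cout << "step " << step << ": pull/push weights\n";  // prints at 1, 66, 131
        }
      }
      return 0;
    }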
@@ -72,10 +72,10 @@ class FusedPullWeightKernel : public CPUKernel {
     const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
     int retcode = schema::ResponseCode_SucNotReady;
     while (retcode == schema::ResponseCode_SucNotReady) {
-      if (!ps::worker::FLWorker::GetInstance().SendToServer(
+      if (!fl::worker::FLWorker::GetInstance().SendToServer(
             0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
         MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. This iteration is dropped.";
-        ps::worker::FLWorker::GetInstance().SetIterationRunning();
+        fl::worker::FLWorker::GetInstance().SetIterationRunning();
         return true;
       }
       MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
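The surrounding loop retries the pull while the server answers ResponseCode_SucNotReady. A self-contained sketch of that poll-until-ready shape; the retry cap and back-off below are assumptions added for the example, the kernel itself simply loops on the response code:

    #include <chrono>
    #include <functional>
    #include <iostream>
    #include <thread>

    enum ResponseCode { kSucceed, kSucNotReady, kFailed };  // illustrative codes

    // Polls `send_request` until the server stops answering "not ready".
    bool PullUntilReady(const std::function<ResponseCode()> &send_request) {
      ResponseCode retcode = kSucNotReady;
      int attempts = 0;
      while (retcode == kSucNotReady) {
        retcode = send_request();
        if (retcode == kSucNotReady) {
          if (++attempts > 100) return false;  // give up instead of spinning forever
          std::this_thread::sleep_for(std::chrono::milliseconds(200));
        }
      }
      return retcode == kSucceed;
    }

    int main() {
      int calls = 0;
      // Fake server: not ready for the first two calls, then succeeds.
      bool ok = PullUntilReady([&calls]() { return ++calls < 3 ? kSucNotReady : kSucceed; });
      std::cout << (ok ? "weights pulled" : "iteration dropped") << " after " << calls << " calls\n";
      return 0;
    }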
@@ -116,7 +116,7 @@ class FusedPullWeightKernel : public CPUKernel {
       }
     }
     MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
-    ps::worker::FLWorker::GetInstance().SetIterationRunning();
+    fl::worker::FLWorker::GetInstance().SetIterationRunning();
     return true;
   }

@@ -154,7 +154,7 @@ class FusedPullWeightKernel : public CPUKernel {
   void InitSizeLists() { return; }

  private:
-  bool BuildPullWeightReq(std::shared_ptr<ps::FBBuilder> fbb) {
+  bool BuildPullWeightReq(std::shared_ptr<fl::FBBuilder> fbb) {
     MS_EXCEPTION_IF_NULL(fbb);
     std::vector<flatbuffers::Offset<flatbuffers::String>> fbs_weight_names;
     for (const std::string &weight_name : weight_full_names_) {

@@ -45,13 +45,13 @@ class FusedPushWeightKernel : public CPUKernel {
       return false;
     }

-    std::shared_ptr<ps::FBBuilder> fbb = std::make_shared<ps::FBBuilder>();
+    std::shared_ptr<fl::FBBuilder> fbb = std::make_shared<fl::FBBuilder>();
     MS_EXCEPTION_IF_NULL(fbb);

     total_iteration_++;
     // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
-    if (total_iteration_ % ps::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
-        ps::kTrainBeginStepNum) {
+    if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
+        fl::kTrainBeginStepNum) {
       return true;
     }

@@ -67,17 +67,17 @@ class FusedPushWeightKernel : public CPUKernel {
     }

     // The server number may change after scaling in/out.
-    for (uint32_t i = 0; i < ps::worker::FLWorker::GetInstance().server_num(); i++) {
+    for (uint32_t i = 0; i < fl::worker::FLWorker::GetInstance().server_num(); i++) {
       std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr;
       const schema::ResponsePushWeight *push_weight_rsp = nullptr;
       int retcode = schema::ResponseCode_SucNotReady;
       while (retcode == schema::ResponseCode_SucNotReady) {
-        if (!ps::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
+        if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
                                                               ps::core::TcpUserCommand::kPushWeight,
                                                               &push_weight_rsp_msg)) {
           MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i
                           << " failed. This iteration is dropped.";
-          ps::worker::FLWorker::GetInstance().SetIterationCompleted();
+          fl::worker::FLWorker::GetInstance().SetIterationCompleted();
           return true;
         }
         MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);

@@ -105,7 +105,7 @@ class FusedPushWeightKernel : public CPUKernel {
     }

     MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
-    ps::worker::FLWorker::GetInstance().SetIterationCompleted();
+    fl::worker::FLWorker::GetInstance().SetIterationCompleted();
     return true;
   }

@@ -143,7 +143,7 @@ class FusedPushWeightKernel : public CPUKernel {
   void InitSizeLists() { return; }

  private:
-  bool BuildPushWeightReq(std::shared_ptr<ps::FBBuilder> fbb, const std::vector<AddressPtr> &weights) {
+  bool BuildPushWeightReq(std::shared_ptr<fl::FBBuilder> fbb, const std::vector<AddressPtr> &weights) {
     std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
     for (size_t i = 0; i < weight_full_names_.size(); i++) {
       const std::string &weight_name = weight_full_names_[i];

@@ -31,8 +31,8 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
   int return_num = 0;
   cipher_meta_storage_.RegisterClass();
   const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN);
-  cipher_meta_storage_.RegisterPrime(ps::server::kCtxCipherPrimer, new_prime);
-  if (!cipher_meta_storage_.GetPrimeFromServer(ps::server::kCtxCipherPrimer, publicparam_.prime)) {
+  cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime);
+  if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) {
     MS_LOG(ERROR) << "Cipher Param Update is invalid.";
     return false;
   }

@@ -45,7 +45,7 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
   publicparam_.t = param.t;
   secrets_minnums_ = param.t;
   client_num_need_ = cipher_initial_client_cnt;
-  featuremap_ = ps::server::ModelStore::GetInstance().model_size() / sizeof(float);
+  featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float);
   share_clients_num_need_ = cipher_share_secrets_cnt;
   reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
   get_model_num_need_ = cipher_get_clientlist_cnt;

@@ -21,7 +21,7 @@ namespace mindspore {
 namespace armour {
 bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
                          const schema::GetExchangeKeys *get_exchange_keys_req,
-                         std::shared_ptr<ps::server::FBBuilder> get_exchange_keys_resp_builder) {
+                         std::shared_ptr<fl::server::FBBuilder> get_exchange_keys_resp_builder) {
   MS_LOG(INFO) << "CipherMgr::GetKeys START";
   if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) {
     MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";

@@ -32,7 +32,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
   // get clientlist from memory server.
   std::vector<std::string> clients;

-  cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList, &clients);
+  cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &clients);

   size_t cur_clients_num = clients.size();
   std::string fl_id = get_exchange_keys_req->fl_id()->str();

@@ -61,7 +61,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim

 bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
                               const schema::RequestExchangeKeys *exchange_keys_req,
-                              std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder) {
+                              std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder) {
   MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
   // step 0: judge if the input param is legal.
   if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) {

@@ -75,8 +75,8 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
   // step 1: get clientlist and client keys from memory server.
   std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
   std::vector<std::string> client_list;
-  cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList, &client_list);
-  cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys);
+  cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
+  cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);

   // step2: process new item data. and update new item data to memory server.
   size_t cur_clients_num = client_list.size();

@@ -131,9 +131,9 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
   cur_public_key.push_back(spk);

   bool retcode_key =
-    cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(ps::server::kCtxClientsKeys, fl_id, cur_public_key);
+    cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, fl_id, cur_public_key);
   bool retcode_client =
-    cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxExChangeKeysClientList, fl_id);
+    cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
   if (retcode_key && retcode_client) {
     MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
     BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,

@@ -147,7 +147,7 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
   }
 }

-void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder,
+void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder,
                                       const schema::ResponseCode retcode, const std::string &reason,
                                       const std::string &next_req_time, const int iteration) {
   auto rsp_reason = exchange_keys_resp_builder->CreateString(reason);

@@ -162,7 +162,7 @@ void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exc
   return;
 }

-bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode,
+bool CipherKeys::BuildGetKeys(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
                               const int iteration, const std::string &next_req_time, bool is_good) {
   bool flag = true;
   if (is_good) {

@@ -170,7 +170,7 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
     std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
     MS_LOG(INFO) << "Get Keys: ";
     std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
-    cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys);
+    cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
     if (record_public_keys.size() < cipher_init_->client_num_need_) {
       MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
                    << "clients num: " << cipher_init_->client_num_need_;

@@ -221,8 +221,8 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
 }

 void CipherKeys::ClearKeys() {
-  ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxExChangeKeysClientList);
-  ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsKeys);
+  fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxExChangeKeysClientList);
+  fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsKeys);
 }

 }  // namespace armour

@@ -45,18 +45,18 @@ class CipherKeys {
   // handle the client's request of get keys.
   bool GetKeys(const int cur_iterator, const std::string &next_req_time,
                const schema::GetExchangeKeys *get_exchange_keys_req,
-               std::shared_ptr<ps::server::FBBuilder> get_exchange_keys_resp_builder);
+               std::shared_ptr<fl::server::FBBuilder> get_exchange_keys_resp_builder);

   // handle the client's request of exchange keys.
   bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
                     const schema::RequestExchangeKeys *exchange_keys_req,
-                    std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder);
+                    std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder);

   // build response code of get keys.
-  bool BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode, const int iteration,
+  bool BuildGetKeys(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode, const int iteration,
                     const std::string &next_req_time, bool is_good);
   // build response code of exchange keys.
-  void BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder,
+  void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder,
                             const schema::ResponseCode retcode, const std::string &reason,
                             const std::string &next_req_time, const int iteration);
   // clear the shared memory.

@@ -21,16 +21,16 @@ namespace armour {

 void CipherMetaStorage::GetClientSharesFromServer(
   const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) {
-  const ps::PBMetadata &clients_shares_pb_out =
-    ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
-  const ps::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
+  const fl::PBMetadata &clients_shares_pb_out =
+    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
+  const fl::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
   auto iter = clients_shares_pb.client_secret_shares().begin();
   for (; iter != clients_shares_pb.client_secret_shares().end(); ++iter) {
     std::string fl_id = iter->first;
-    const ps::SharesPb &shares_pb = iter->second;
+    const fl::SharesPb &shares_pb = iter->second;
     std::vector<clientshare_str> encrpted_shares_new;
     for (int index_shares = 0; index_shares < shares_pb.clientsharestrs_size(); ++index_shares) {
-      const ps::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
+      const fl::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
       clientshare_str new_clientshare;
       new_clientshare.fl_id = client_share_str_pb.fl_id();
       new_clientshare.index = client_share_str_pb.index();

@@ -42,8 +42,8 @@ void CipherMetaStorage::GetClientSharesFromServer(
 }

 void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) {
-  const ps::PBMetadata &client_list_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
-  const ps::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
+  const fl::PBMetadata &client_list_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
+  const fl::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
   for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
     std::string fl_id = client_list_pb.fl_id(i);
     clients_list->push_back(fl_id);

@@ -52,14 +52,14 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect

 void CipherMetaStorage::GetClientKeysFromServer(
   const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) {
-  const ps::PBMetadata &clients_keys_pb_out =
-    ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
-  const ps::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
+  const fl::PBMetadata &clients_keys_pb_out =
+    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
+  const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();

   for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
     // const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
     std::string fl_id = iter->first;
-    ps::KeysPb keys_pb = iter->second;
+    fl::KeysPb keys_pb = iter->second;
     std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
     std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
     std::vector<std::vector<unsigned char>> cur_keys;

@@ -70,9 +70,9 @@ void CipherMetaStorage::GetClientKeysFromServer(
 }

 bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) {
-  const ps::PBMetadata &clients_noises_pb_out =
-    ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
-  const ps::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
+  const fl::PBMetadata &clients_noises_pb_out =
+    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
+  const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
   while (clients_noises_pb.has_one_client_noises() == false) {
     MS_LOG(INFO) << "GetClientNoisesFromServer NULL.";
     std::this_thread::sleep_for(std::chrono::milliseconds(50));

@@ -83,8 +83,8 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
 }

 bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) {
-  const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
-  ps::Prime prime_pb(prime_pb_out.prime());
+  const fl::PBMetadata &prime_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
+  fl::Prime prime_pb(prime_pb_out.prime());
   std::string str = *(prime_pb.mutable_prime());
   MS_LOG(INFO) << "get prime from metastorage :" << str;

@@ -99,20 +99,20 @@ bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char

 bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
   bool retcode = true;
-  ps::FLId fl_id_pb;
+  fl::FLId fl_id_pb;
   fl_id_pb.set_fl_id(fl_id);
-  ps::PBMetadata client_pb;
+  fl::PBMetadata client_pb;
   client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
-  retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
+  retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
   return retcode;
 }
 void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
   MS_LOG(INFO) << "register prime: " << prime;
-  ps::Prime prime_id_pb;
+  fl::Prime prime_id_pb;
   prime_id_pb.set_prime(prime);
-  ps::PBMetadata prime_pb;
+  fl::PBMetadata prime_pb;
   prime_pb.mutable_prime()->MergeFrom(prime_id_pb);
-  ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
+  fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
 }

 bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,

@@ -123,25 +123,25 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std
     return false;
   }
   // update new item to memory server.
-  ps::KeysPb keys;
+  fl::KeysPb keys;
   keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
   keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
-  ps::PairClientKeys pair_client_keys_pb;
+  fl::PairClientKeys pair_client_keys_pb;
   pair_client_keys_pb.set_fl_id(fl_id);
   pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
-  ps::PBMetadata client_and_keys_pb;
+  fl::PBMetadata client_and_keys_pb;
   client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
-  retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
+  retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
   return retcode;
 }

 bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) {
   // update new item to memory server.
-  ps::OneClientNoises noises_pb;
+  fl::OneClientNoises noises_pb;
   *noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
-  ps::PBMetadata client_noises_pb;
+  fl::PBMetadata client_noises_pb;
   client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
-  return ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
+  return fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
 }

 bool CipherMetaStorage::UpdateClientShareToServer(

@@ -149,10 +149,10 @@ bool CipherMetaStorage::UpdateClientShareToServer(
   const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
   bool retcode = true;
   int size_shares = shares->size();
-  ps::SharesPb shares_pb;
+  fl::SharesPb shares_pb;
   for (int index = 0; index < size_shares; ++index) {
     // new item
-    ps::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
+    fl::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
     std::string fl_id_new = (*shares)[index]->fl_id()->str();
     int index_new = (*shares)[index]->index();
     auto share = (*shares)[index]->share();

@@ -160,32 +160,32 @@ bool CipherMetaStorage::UpdateClientShareToServer(
     client_share_str_new_p->set_fl_id(fl_id_new);
     client_share_str_new_p->set_index(index_new);
   }
-  ps::PairClientShares pair_client_shares_pb;
+  fl::PairClientShares pair_client_shares_pb;
   pair_client_shares_pb.set_fl_id(fl_id);
   pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
-  ps::PBMetadata client_and_shares_pb;
+  fl::PBMetadata client_and_shares_pb;
   client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
-  retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
+  retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
   return retcode;
 }

 void CipherMetaStorage::RegisterClass() {
-  ps::PBMetadata exchange_kyes_client_list;
-  ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxExChangeKeysClientList,
+  fl::PBMetadata exchange_kyes_client_list;
+  fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
                                                                        exchange_kyes_client_list);
-  ps::PBMetadata clients_keys;
-  ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsKeys, clients_keys);
-  ps::PBMetadata reconstruct_client_list;
-  ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxReconstructClientList,
+  fl::PBMetadata clients_keys;
+  fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
+  fl::PBMetadata reconstruct_client_list;
+  fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxReconstructClientList,
                                                                        reconstruct_client_list);
-  ps::PBMetadata clients_reconstruct_shares;
-  ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsReconstructShares,
+  fl::PBMetadata clients_reconstruct_shares;
+  fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsReconstructShares,
                                                                        clients_reconstruct_shares);
-  ps::PBMetadata share_secretes_client_list;
-  ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxShareSecretsClientList,
+  fl::PBMetadata share_secretes_client_list;
+  fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
                                                                        share_secretes_client_list);
-  ps::PBMetadata clients_encrypt_shares;
-  ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsEncryptedShares,
+  fl::PBMetadata clients_encrypt_shares;
+  fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
                                                                        clients_encrypt_shares);
 }
 }  // namespace armour

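Everything in this file follows one register/update/get round-trip against DistributedMetadataStore, with each value wrapped in a PBMetadata protobuf before it is stored. A toy in-memory equivalent of that pattern (the store and message types below are stand-ins, not the real protobufs):

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Toy stand-in for fl::PBMetadata + DistributedMetadataStore: one keyed slot
    // per metadata name, each holding a list of client ids.
    class ToyMetadataStore {
     public:
      static ToyMetadataStore &GetInstance() {
        static ToyMetadataStore instance;
        return instance;
      }
      void RegisterMetadata(const std::string &name) { store_[name]; }          // like RegisterClass()
      bool UpdateMetadata(const std::string &name, const std::string &fl_id) {  // like UpdateClientToServer()
        auto it = store_.find(name);
        if (it == store_.end()) return false;  // must be registered first
        it->second.push_back(fl_id);
        return true;
      }
      std::vector<std::string> GetMetadata(const std::string &name) {  // like GetClientListFromServer()
        return store_[name];
      }

     private:
      std::map<std::string, std::vector<std::string>> store_;
    };

    int main() {
      auto &store = ToyMetadataStore::GetInstance();
      store.RegisterMetadata("ExchangeKeysClientList");
      store.UpdateMetadata("ExchangeKeysClientList", "client_0");
      store.UpdateMetadata("ExchangeKeysClientList", "client_1");
      for (const auto &id : store.GetMetadata("ExchangeKeysClientList")) std::cout << id << "\n";
      return 0;
    }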
@@ -101,15 +101,15 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
   MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
   bool retcode = true;
   std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori;
-  cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsReconstructShares,
+  cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
                                                                &reconstruct_secret_list_ori);
   std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
-  cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys);
+  cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
   std::vector<std::string> clients_reconstruct_list;
-  cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxReconstructClientList,
+  cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
                                                              &clients_reconstruct_list);
   std::vector<std::string> clients_share_list;
-  cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList,
+  cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
                                                              &clients_share_list);
   if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() ||
       record_public_keys.size() < cipher_init_->client_num_need_ ||

@@ -146,7 +146,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
     client_keys.clear();
     MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";

-    if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(ps::server::kCtxClientNoises, noise) == false)
+    if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise) == false)
       return false;
     MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
   } else {

@@ -159,7 +159,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
 // reconstruct secrets
 bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
                                            const schema::SendReconstructSecret *reconstruct_secret_req,
-                                           std::shared_ptr<ps::server::FBBuilder> reconstruct_secret_resp_builder,
+                                           std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
                                            const std::vector<std::string> &client_list) {
   MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
   clock_t start_time = clock();

@@ -178,10 +178,10 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
     return false;
   }
   std::vector<std::string> clients_reconstruct_list;
-  cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxReconstructClientList,
+  cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
                                                              &clients_reconstruct_list);
   std::map<std::string, std::vector<clientshare_str>> clients_shares_all;
-  cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsReconstructShares,
+  cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
                                                                &clients_shares_all);

   size_t count_client_num = clients_shares_all.size();

@@ -215,9 +215,9 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
   }
   auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
   bool retcode_client =
-    cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxReconstructClientList, fl_id);
+    cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id);
   bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
-    ps::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
+    fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
   if (!(retcode_client && retcode_share)) {
     BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
                                "reconstruct update shares or client failed.", cur_iterator, next_req_time);

@@ -233,9 +233,9 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
     return true;
   } else {
     bool retcode_result = true;
-    const ps::PBMetadata &clients_noises_pb_out =
-      ps::server::DistributedMetadataStore::GetInstance().GetMetadata(ps::server::kCtxClientNoises);
-    const ps::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
+    const fl::PBMetadata &clients_noises_pb_out =
+      fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
+    const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
     if (clients_noises_pb.has_one_client_noises() == false) {
       MS_LOG(INFO) << "Success,the secret will be reconstructed.";
       retcode_result = ReconstructSecretsGenNoise(client_list);

@@ -279,13 +279,13 @@ bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,

 void CipherReconStruct::ClearReconstructSecrets() {
   MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets START";
-  ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxReconstructClientList);
-  ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsReconstructShares);
-  ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientNoises);
+  fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxReconstructClientList);
+  fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsReconstructShares);
+  fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientNoises);
   MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success";
 }

-void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<ps::server::FBBuilder> fbb,
+void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb,
                                                    const schema::ResponseCode retcode, const std::string &reason,
                                                    const int iteration, const std::string &next_req_time) {
   auto fbs_reason = fbb->CreateString(reason);

@@ -44,11 +44,11 @@ class CipherReconStruct {
   // reconstruct secret mask
   bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
                           const schema::SendReconstructSecret *reconstruct_secret_req,
-                          std::shared_ptr<ps::server::FBBuilder> reconstruct_secret_resp_builder,
+                          std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
                           const std::vector<std::string> &client_list);

   // build response code of reconstruct secret.
-  void BuildReconstructSecretsRsp(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode,
+  void BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
                                   const std::string &reason, const int iteration, const std::string &next_req_time);

   // clear the shared memory.

@@ -21,7 +21,7 @@
 namespace mindspore {
 namespace armour {
 bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
-                                std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder,
+                                std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
                                 const string next_req_time) {
   MS_LOG(INFO) << "CipherShares::ShareSecrets START";
   if (share_secrets_req == nullptr) {

@@ -35,13 +35,13 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
   // step 1: get client list and share secrets from memory server.
   clock_t start_time = clock();
   std::vector<std::string> clients_share_list;
-  cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList,
+  cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
                                                              &clients_share_list);
   std::vector<std::string> clients_exchange_list;
-  cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList,
+  cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList,
                                                              &clients_exchange_list);
   std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
-  cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsEncryptedShares,
+  cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
                                                                &encrypted_shares_all);

   MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size()

@@ -75,9 +75,9 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
   const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
     (share_secrets_req->encrypted_shares());
   bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
-    ps::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
+    fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
   bool retcode_client =
-    cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxShareSecretsClientList, fl_id_src);
+    cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src);
   bool retcode = retcode_share && retcode_client;
   if (retcode) {
     BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);

@@ -95,7 +95,7 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
 }

 bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
-                              std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder,
+                              std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
                               const std::string &next_req_time) {
   MS_LOG(INFO) << "CipherShares::GetSecrets START";
   clock_t start_time = clock();

@@ -108,10 +108,10 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,

   // step 1: get client list and client shares list from memory server.
   std::vector<std::string> clients_share_list;
-  cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList,
+  cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
                                                              &clients_share_list);
   std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
-  cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsEncryptedShares,
+  cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
                                                                &encrypted_shares_all);
   int iteration = get_secrets_req->iteration();
   size_t share_clients_num = clients_share_list.size();

@@ -180,7 +180,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
 }

 void CipherShares::BuildGetSecretsRsp(
-  std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
+  std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
   std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) {
   int rsp_retcode = retcode;
   int rsp_iteration = iteration;

@@ -199,7 +199,7 @@ void CipherShares::BuildGetSecretsRsp(
   return;
 }

-void CipherShares::BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder,
+void CipherShares::BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
                                         const schema::ResponseCode retcode, const string &reason,
                                         const string &next_req_time, const int iteration) {
   auto rsp_reason = share_secrets_resp_builder->CreateString(reason);

@@ -211,8 +211,8 @@ void CipherShares::BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> s
 }

 void CipherShares::ClearShareSecrets() {
-  ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxShareSecretsClientList);
-  ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsEncryptedShares);
+  fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxShareSecretsClientList);
+  fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsEncryptedShares);
 }

 }  // namespace armour

@@ -43,17 +43,17 @@ class CipherShares {

   // handle the client's request of share secrets.
   bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
-                    std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder, const string next_req_time);
+                    std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, const string next_req_time);
   // handle the client's request of get secrets.
   bool GetSecrets(const schema::GetShareSecrets *get_secrets_req,
-                  std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time);
+                  std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time);

   // build response code of share secrets.
-  void BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder,
+  void BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
                             const schema::ResponseCode retcode, const string &reason, const string &next_req_time,
                             const int iteration);
   // build response code of get secrets.
-  void BuildGetSecretsRsp(std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder,
+  void BuildGetSecretsRsp(std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
                           const schema::ResponseCode retcode, const int iteration, std::string next_req_time,
                           std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares);
   // clear the shared memory.

@@ -26,13 +26,13 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
   clock_t start_time = clock();
   std::vector<float> noise;

-  cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(ps::server::kCtxClientNoises, &noise);
+  cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
   if (noise.size() != cipher_init_->featuremap_) {
     MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
     return false;
   }

-  size_t data_size = ps::server::LocalMetaStore::GetInstance().value<size_t>(ps::server::kCtxFedAvgTotalDataSize);
+  size_t data_size = fl::server::LocalMetaStore::GetInstance().value<size_t>(fl::server::kCtxFedAvgTotalDataSize);
   int sum_size = 0;
   for (auto iter = data.begin(); iter != data.end(); ++iter) {
     int size_data = iter->second->size / sizeof(float);

@@ -17,13 +17,13 @@
 #include "fl/server/collective_ops_impl.h"

 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
-void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
+void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
   MS_EXCEPTION_IF_NULL(server_node);
   server_node_ = server_node;
   local_rank_ = server_node_->rank_id();
-  server_num_ = PSContext::instance()->initial_server_num();
+  server_num_ = ps::PSContext::instance()->initial_server_num();
   return;
 }

@@ -66,7 +66,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
     // Step 1: Async send data to next rank.
     size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
     T *send_chunk = output_buff + chunk_offset[send_chunk_index];
-    auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
+    auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
                                                          chunk_sizes[send_chunk_index] * sizeof(T));
     // Step 2: Async receive data to next rank and wait until it's done.
     size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;

@@ -76,12 +76,16 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
                   << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;

     std::shared_ptr<std::vector<unsigned char>> recv_str;
-    auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
+    auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
     if (!server_node_->CollectiveWait(recv_req_id)) {
       MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
       return false;
     }
-    memcpy_s(tmp_recv_chunk.get(), chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
+    ret = memcpy_s(tmp_recv_chunk.get(), chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
+      return false;
+    }

     // Step 3: Reduce the data so we can overlap the time cost of send.
     for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) {
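Beyond the rename, this hunk tightens error handling: the memcpy_s return value is now captured and checked rather than discarded. memcpy_s returns 0 on success and a nonzero error code otherwise; a portable sketch of the same checked-copy contract built on plain memcpy (memcpy_s itself comes from C11 Annex K / the securec library):

    #include <cstring>
    #include <iostream>

    // Portable emulation of the memcpy_s contract used in the diff: refuse the
    // copy (nonzero return) when the destination is too small or a pointer is null.
    int CheckedMemcpy(void *dest, size_t dest_max, const void *src, size_t count) {
      if (dest == nullptr || src == nullptr) return 1;
      if (count > dest_max) return 2;  // would overflow the destination buffer
      std::memcpy(dest, src, count);
      return 0;  // success, matching memcpy_s's 0-on-success convention
    }

    int main() {
      char dst[8];
      const char payload[] = "0123456789";  // 11 bytes, too big for dst
      int ret = CheckedMemcpy(dst, sizeof(dst), payload, sizeof(payload));
      if (ret != 0) {
        std::cerr << "memcpy_s error, errorno(" << ret << ")\n";  // same log shape as the diff
        return 1;
      }
      return 0;
    }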
@@ -100,7 +104,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
   for (size_t i = 0; i < rank_size - 1; i++) {
     size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
     T *send_chunk = output_buff + chunk_offset[send_chunk_index];
-    auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
+    auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
                                                          chunk_sizes[send_chunk_index] * sizeof(T));
     size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
     T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];

@@ -109,13 +113,17 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
                   << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;

     std::shared_ptr<std::vector<unsigned char>> recv_str;
-    auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
+    auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);

     if (!server_node_->CollectiveWait(recv_req_id)) {
       MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
       return false;
     }
-    memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
+    ret = memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size());
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
+      return false;
+    }
     if (!server_node_->Wait(send_req_id, 1)) {
       MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
       return false;

@@ -143,19 +151,23 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
     for (uint32_t i = 1; i < rank_size; i++) {
       std::shared_ptr<std::vector<unsigned char>> recv_str;
       MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
-      auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str);
+      auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
       if (!server_node_->CollectiveWait(recv_req_id)) {
         MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
         return false;
       }
-      memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
+      ret = memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
+      if (ret != 0) {
+        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
+        return false;
+      }
       for (size_t j = 0; j < count; j++) {
         output_buff[j] += tmp_recv_buff[j];
       }
     }
   } else {
     MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
-    auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
+    auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
     if (!server_node_->Wait(send_req_id)) {
       MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
       return false;

@@ -168,7 +180,8 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
   if (local_rank_ == 0) {
     for (uint32_t i = 1; i < rank_size; i++) {
       MS_LOG(DEBUG) << "Broadcast data to process " << i;
-      auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
+      auto send_req_id =
+        server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
       if (!server_node_->Wait(send_req_id)) {
         MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
         return false;

@@ -177,12 +190,16 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
   } else {
     MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
     std::shared_ptr<std::vector<unsigned char>> recv_str;
-    auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str);
+    auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
     if (!server_node_->CollectiveWait(recv_req_id)) {
       MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
       return false;
     }
-    memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
+    ret = memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
+    if (ret != 0) {
+      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
+      return false;
+    }
   }
   MS_LOG(DEBUG) << "End broadcast.";
   return true;

@@ -231,5 +248,5 @@ template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *re
 template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
 template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
 }  // namespace server
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
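The chunk arithmetic in RingAllReduce above — send_chunk_index = (local_rank_ - i + rank_size) % rank_size and its receive counterpart — is easier to verify outside the transport code. A single-process toy simulation of the scatter-reduce phase using the same indices (no networking, one float per chunk):

    #include <iostream>
    #include <vector>

    // Toy scatter-reduce phase of ring all-reduce: data[r][c] is chunk c on rank r.
    // After rank_size - 1 steps, chunk (r + 1) % rank_size on rank r holds the
    // fully reduced sum, using the same index arithmetic as RingAllReduce.
    int main() {
      const size_t rank_size = 4;
      std::vector<std::vector<float>> data(rank_size, std::vector<float>(rank_size));
      for (size_t r = 0; r < rank_size; ++r)
        for (size_t c = 0; c < rank_size; ++c) data[r][c] = static_cast<float>(r + 1);

      for (size_t i = 0; i < rank_size - 1; ++i) {
        std::vector<float> in_flight(rank_size);
        for (size_t r = 0; r < rank_size; ++r) {  // every rank sends one chunk to rank r+1
          size_t send_chunk_index = (r - i + rank_size) % rank_size;
          in_flight[(r + 1) % rank_size] = data[r][send_chunk_index];
        }
        for (size_t r = 0; r < rank_size; ++r) {  // and reduces the chunk it received
          size_t recv_chunk_index = (r - i - 1 + rank_size) % rank_size;
          data[r][recv_chunk_index] += in_flight[r];
        }
      }
      for (size_t r = 0; r < rank_size; ++r) {
        size_t done = (r + 1) % rank_size;  // chunk that finished reducing on rank r
        std::cout << "rank " << r << " chunk " << done << " = " << data[r][done] << "\n";  // 10 everywhere
      }
      return 0;
    }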
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
-#define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
+#ifndef MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
+#define MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_

 #include <memory>
 #include <string>

@@ -27,7 +27,7 @@
 #include "fl/server/common.h"

 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
 // CollectiveOpsImpl is the collective communication API of the server.
 // For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also

@@ -39,7 +39,7 @@ class CollectiveOpsImpl {
     return instance;
   }

-  void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
+  void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);

   template <typename T>
   bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);

@@ -48,7 +48,7 @@ class CollectiveOpsImpl {
   bool ReInitForScaling();

  private:
-  CollectiveOpsImpl() = default;
+  CollectiveOpsImpl() : server_node_(nullptr), local_rank_(0), server_num_(0) {}
   ~CollectiveOpsImpl() = default;
   CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
   CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;
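The constructor change above is not part of the rename: replacing the defaulted constructor with an explicit member-initializer list guarantees local_rank_ and server_num_ are zero before Initialize() runs, whereas `= default` leaves those uint32_t members indeterminate. A toy illustration of the difference:

    #include <cstdint>
    #include <iostream>
    #include <memory>

    struct NodeStub {};  // stand-in for ps::core::ServerNode

    class ToyCollectiveOps {
     public:
      // Explicit member init, mirroring the new constructor in the diff: every
      // field has a deterministic value even before Initialize() is called.
      ToyCollectiveOps() : server_node_(nullptr), local_rank_(0), server_num_(0) {}

      // With `ToyCollectiveOps() = default;` instead, local_rank_ and server_num_
      // would hold indeterminate values here and reading them would be undefined
      // behavior (the shared_ptr alone would still default to null).
      uint32_t local_rank() const { return local_rank_; }

     private:
      std::shared_ptr<NodeStub> server_node_;
      uint32_t local_rank_;
      uint32_t server_num_;
    };

    int main() {
      ToyCollectiveOps ops;
      std::cout << "rank before Initialize(): " << ops.local_rank() << "\n";  // always 0 now
      return 0;
    }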
@@ -61,7 +61,7 @@ class CollectiveOpsImpl {
   template <typename T>
   bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);

-  std::shared_ptr<core::ServerNode> server_node_;
+  std::shared_ptr<ps::core::ServerNode> server_node_;
   uint32_t local_rank_;
   uint32_t server_num_;

@@ -69,6 +69,6 @@ class CollectiveOpsImpl {
   std::mutex mtx_;
 };
 }  // namespace server
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
+#endif  // MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_

@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
-#define MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
+#ifndef MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
+#define MINDSPORE_CCSRC_FL_SERVER_COMMON_H_

 #include <map>
 #include <string>

@@ -37,7 +37,7 @@
 #include "ps/core/communicator/message_handler.h"

 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
 // Definitions for the server framework.
 enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };

@@ -73,7 +73,7 @@ using TimeOutCb = std::function<void(bool, const std::string &)>;
 using StopTimerCb = std::function<void(void)>;
 using FinishIterCb = std::function<void(bool, const std::string &)>;
 using FinalizeCb = std::function<void(void)>;
-using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
+using MessageCallback = std::function<void(const std::shared_ptr<ps::core::MessageHandler> &)>;

 // Information about whether server kernel will reuse kernel node memory from the front end.
 // Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".

@@ -237,6 +237,6 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size
 // Definitions for Parameter Server.

 }  // namespace server
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
+#endif  // MINDSPORE_CCSRC_FL_SERVER_COMMON_H_

@@ -17,7 +17,7 @@
 #include "fl/server/consistent_hash_ring.h"

 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
 bool ConsistentHashRing::Insert(uint32_t rank) {
   for (uint32_t i = 0; i < virtual_node_num_; i++) {

@@ -53,5 +53,5 @@ uint32_t ConsistentHashRing::Find(const std::string &key) {
   return iterator->second;
 }
 }  // namespace server
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
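consistent_hash_ring.cc only changes namespaces, but the Insert/Find pair it implements is worth a sketch: each server rank is hashed onto the ring virtual_node_num_ times, and Find(key) walks clockwise to the first virtual node at or after the key's hash. A compact toy version built on std::hash — the concept, not the class's actual hashing:

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <map>
    #include <string>

    class ToyHashRing {
     public:
      explicit ToyHashRing(uint32_t virtual_node_num) : virtual_node_num_(virtual_node_num) {}

      // Place virtual_node_num_ virtual nodes for this rank on the ring.
      void Insert(uint32_t rank) {
        for (uint32_t i = 0; i < virtual_node_num_; ++i) {
          ring_[std::hash<std::string>{}(std::to_string(rank) + "#" + std::to_string(i))] = rank;
        }
      }

      // Walk clockwise: the first virtual node at or after the key's hash owns the key.
      uint32_t Find(const std::string &key) const {
        auto it = ring_.lower_bound(std::hash<std::string>{}(key));
        if (it == ring_.end()) it = ring_.begin();  // wrap around the ring
        return it->second;
      }

     private:
      uint32_t virtual_node_num_;
      std::map<size_t, uint32_t> ring_;  // hash position -> rank, same shape as ring_ in the header
    };

    int main() {
      ToyHashRing ring(32);
      for (uint32_t rank = 0; rank < 4; ++rank) ring.Insert(rank);
      std::cout << "metadata key 'ClientsKeys' lives on rank " << ring.Find("ClientsKeys") << "\n";
      return 0;
    }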
@@ -14,15 +14,15 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
-#define MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
+#ifndef MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_
+#define MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_

 #include <map>
 #include <string>
 #include "utils/log_adapter.h"

 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
 // To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in
 // server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node.

@@ -59,6 +59,6 @@ class ConsistentHashRing {
   std::map<size_t, uint32_t> ring_;
 };
 }  // namespace server
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
+#endif  // MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_

@@ -20,19 +20,19 @@
 #include <vector>

 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
-void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node,
+void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node,
                                          uint32_t counting_server_rank) {
   server_node_ = server_node;
   MS_EXCEPTION_IF_NULL(server_node_);
   local_rank_ = server_node_->rank_id();
-  server_num_ = PSContext::instance()->initial_server_num();
+  server_num_ = ps::PSContext::instance()->initial_server_num();
   counting_server_rank_ = counting_server_rank;
   return;
 }

-void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
+void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
   communicator_ = communicator;
   MS_EXCEPTION_IF_NULL(communicator_);
   communicator_->RegisterMsgCallBack(

@@ -94,14 +94,14 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
   report_count_req.set_id(id);

   std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
-  if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount,
+  if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
                                     &report_cnt_rsp_msg)) {
     MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
     return false;
   }

   CountResponse count_rsp;
-  count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), report_cnt_rsp_msg->size());
+  count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), SizeToInt(report_cnt_rsp_msg->size()));
   if (!count_rsp.result()) {
     MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason();
     return false;

@@ -126,13 +126,14 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {

   std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
   if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
-                                    core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
+                                    ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
     MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
     return false;
   }

   CountReachThresholdResponse count_reach_threshold_rsp;
-  count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size());
+  count_reach_threshold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(),
+                                           SizeToInt(query_cnt_enough_rsp_msg->size()));
   return count_reach_threshold_rsp.is_enough();
 }
 }
@ -164,14 +165,14 @@ bool DistributedCountService::ReInitForScaling() {
|
|||
return true;
|
||||
}
|
||||
|
||||
void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
CountRequest report_count_req;
|
||||
report_count_req.ParseFromArray(message->data(), message->len());
|
||||
report_count_req.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
const std::string &name = report_count_req.name();
|
||||
const std::string &id = report_count_req.id();
|
||||
|
||||
|
@ -213,14 +214,15 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::Mes
|
|||
return;
|
||||
}
|
||||
|
||||
void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
void DistributedCountService::HandleCountReachThresholdRequest(
|
||||
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
}
|
||||
|
||||
CountReachThresholdRequest count_reach_threshold_req;
|
||||
count_reach_threshold_req.ParseFromArray(message->data(), message->len());
|
||||
count_reach_threshold_req.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
const std::string &name = count_reach_threshold_req.name();
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_[name]);
|
||||
|
@ -236,7 +238,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
|
|||
return;
|
||||
}
|
||||
|
||||
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||
if (message == nullptr) {
|
||||
MS_LOG(ERROR) << "Message is nullptr.";
|
||||
return;
|
||||
|
@ -248,7 +250,7 @@ void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::Mes
|
|||
communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message);
|
||||
|
||||
CounterEvent counter_event;
|
||||
counter_event.ParseFromArray(message->data(), message->len());
|
||||
counter_event.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||
const auto &type = counter_event.type();
|
||||
const auto &name = counter_event.name();
|
||||
|
||||
|
@ -289,7 +291,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
|
|||
|
||||
// Broadcast to all follower servers.
|
||||
for (uint32_t i = 1; i < server_num_; i++) {
|
||||
if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
|
||||
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
||||
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -307,7 +309,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
|
|||
|
||||
// Broadcast to all follower servers.
|
||||
for (uint32_t i = 1; i < server_num_; i++) {
|
||||
if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
|
||||
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
||||
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
|
||||
return false;
|
||||
}
|
||||
|
@ -317,5 +319,5 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
|
|||
return true;
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#define MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_

#include <set>
#include <string>

@@ -27,7 +27,7 @@
#include "ps/core/communicator/tcp_communicator.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
constexpr uint32_t kDefaultCountingServerRank = 0;
constexpr auto kModuleDistributedCountService = "DistributedCountService";

@@ -54,10 +54,10 @@ class DistributedCountService {
}

// Initialize counter service with the server node because communication is needed.
void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank);
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node, uint32_t counting_server_rank);

// Register message callbacks of the counting server to handle messages sent by the other servers.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);

// Register counter to the counting server for the name with its threshold count in server cluster dimension and
// first/last count event callbacks.

@@ -87,15 +87,15 @@ class DistributedCountService {
DistributedCountService &operator=(const DistributedCountService &) = delete;

// Callback for the reporting count message from other servers. Only counting server will call this method.
void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message);

// Callback for the querying whether threshold count is reached message from other servers. Only counting
// server will call this method.
void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleCountReachThresholdRequest(const std::shared_ptr<ps::core::MessageHandler> &message);

// Callback for the first/last event message from the counting server. Only other servers will call this
// method.
void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message);
void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message);

// Call the callbacks when the first/last count event is triggered.
bool TriggerCounterEvent(const std::string &name);

@@ -103,8 +103,8 @@ class DistributedCountService {
bool TriggerLastCountEvent(const std::string &name);

// Members for the communication between counting server and other servers.
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
uint32_t local_rank_;
uint32_t server_num_;

@@ -126,6 +126,6 @@ class DistributedCountService {
std::unordered_map<std::string, std::mutex> mutex_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
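DistributedCountService implements a counting contract: each participant reports once under an id, the first accepted report fires the first-count callback, and the report that reaches the threshold fires the last-count callback. A self-contained single-process sketch of that contract, with simplified made-up names (the real service spreads the same logic across servers via SendPbRequest to the counting server):

// Local model of the counting contract; illustrative only.
#include <cstddef>
#include <functional>
#include <iostream>
#include <set>
#include <string>

struct MiniCounter {
  size_t threshold;
  std::set<std::string> ids;  // deduplicate reports by id
  std::function<void()> first_cnt_cb;
  std::function<void()> last_cnt_cb;

  bool Count(const std::string &id) {
    if (ids.size() >= threshold || !ids.insert(id).second) {
      return false;  // already enough reports, or duplicate id
    }
    if (ids.size() == 1) first_cnt_cb();
    if (ids.size() == threshold) last_cnt_cb();
    return true;
  }
  bool CountReachThreshold() const { return ids.size() >= threshold; }
};

int main() {
  MiniCounter counter{2, {}, [] { std::cout << "first count\n"; },
                      [] { std::cout << "last count\n"; }};
  counter.Count("worker_0");
  counter.Count("worker_1");
  std::cout << std::boolalpha << counter.CountReachThreshold() << "\n";  // true
}

In the distributed version only the counting server owns the id set; other servers learn about the first/last events through the kCounterEvent broadcasts shown in the .cc diff above.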
@@ -20,18 +20,18 @@
#include <vector>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
server_node_ = server_node;
MS_EXCEPTION_IF_NULL(server_node);
local_rank_ = server_node_->rank_id();
server_num_ = PSContext::instance()->initial_server_num();
server_num_ = ps::PSContext::instance()->initial_server_num();
InitHashRing();
return;
}

void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
communicator_ = communicator;
MS_EXCEPTION_IF_NULL(communicator_);
communicator_->RegisterMsgCallBack(

@@ -100,7 +100,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
metadata_with_name.set_name(name);
*metadata_with_name.mutable_metadata() = meta;
std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata,
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
&update_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
return false;

@@ -133,12 +133,12 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
PBMetadata get_metadata_rsp;

std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata,
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetMetadata,
&get_meta_rsp_msg)) {
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
return get_metadata_rsp;
}
get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), get_meta_rsp_msg->size());
get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), SizeToInt(get_meta_rsp_msg->size()));
return get_metadata_rsp;
}
}

@@ -174,14 +174,14 @@ void DistributedMetadataStore::InitHashRing() {
return;
}

void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;
}

PBMetadataWithName meta_with_name;
meta_with_name.ParseFromArray(message->data(), message->len());
meta_with_name.ParseFromArray(message->data(), SizeToInt(message->len()));
const std::string &name = meta_with_name.name();
MS_LOG(INFO) << "Update metadata for " << name;

@@ -196,7 +196,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
return;
}

void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;

@@ -267,7 +267,7 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
auto &fl_id = meta.pair_client_shares().fl_id();
auto &client_shares = meta.pair_client_shares().client_shares();
// google::protobuf::Map< std::string, mindspore::ps::core::SharesPb >::const_iterator iter;
// google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter;
// Check whether the new item already exists.
bool add_flag = true;
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); iter++) {

@@ -299,5 +299,5 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
return true;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_
#define MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_

#include <string>
#include <memory>

@@ -27,7 +27,7 @@
#include "fl/server/consistent_hash_ring.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
constexpr auto kModuleDistributedMetadataStore = "DistributedMetadataStore";
// This class is used for distributed metadata storage using consistent hash. All metadata is distributedly

@@ -44,10 +44,10 @@ class DistributedMetadataStore {
}

// Initialize metadata storage with the server node because communication is needed.
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);

// Register callbacks for the server to handle update/get metadata messages from other servers.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);

// Register metadata for the name with the initial value. This method should be only called once for each name.
void RegisterMetadata(const std::string &name, const PBMetadata &meta);

@@ -65,7 +65,13 @@ class DistributedMetadataStore {
bool ReInitForScaling();

private:
DistributedMetadataStore() = default;
DistributedMetadataStore()
: server_node_(nullptr),
communicator_(nullptr),
local_rank_(0),
server_num_(0),
router_(nullptr),
metadata_({}) {}
~DistributedMetadataStore() = default;
DistributedMetadataStore(const DistributedMetadataStore &) = delete;
DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete;

@@ -74,17 +80,17 @@ class DistributedMetadataStore {
void InitHashRing();

// Callback for updating metadata request sent to the server.
void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);

// Callback for getting metadata request sent to the server.
void HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);

// Do updating metadata in the server where the metadata for the name is stored.
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);

// Members for the communication between servers.
std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
uint32_t local_rank_;
uint32_t server_num_;

@@ -100,6 +106,6 @@ class DistributedMetadataStore {
std::unordered_map<std::string, std::mutex> mutex_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_
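DistributedMetadataStore routes every metadata name through the hash ring: if Find(name) returns the local rank the update touches the local map, otherwise the request is forwarded over TCP to the owning server. A sketch of just that routing decision, reusing the MiniHashRing sketch above (RouteUpdate is a hypothetical helper; the real forward path sends a PBMetadataWithName protobuf request):

// Routing decision made by the metadata store; RPC path elided.
#include <cstdint>
#include <string>

bool RouteUpdate(const MiniHashRing &router, uint32_t local_rank,
                 const std::string &name) {
  uint32_t stored_rank = router.Find(name);
  if (stored_rank == local_rank) {
    return true;   // handle locally, as DoUpdateMetadata does under mutex_[name]
  }
  return false;    // forward with SendPbRequest(..., stored_rank, kUpdateMetadata, ...)
}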
@@ -21,7 +21,7 @@
#include <vector>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
MS_EXCEPTION_IF_NULL(func_graph);

@@ -320,5 +320,5 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
return true;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
#define MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_

#include <map>
#include <set>

@@ -31,7 +31,7 @@
#endif

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles
// logics relevant to kernel launching.

@@ -94,7 +94,7 @@ class Executor {
bool Unmask();

private:
Executor() {}
Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}) {}
~Executor() = default;
Executor(const Executor &) = delete;
Executor &operator=(const Executor &) = delete;

@@ -126,6 +126,6 @@ class Executor {
#endif
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
@@ -23,10 +23,10 @@
#include "fl/server/server.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
class Server;
void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator;
communicator_->RegisterMsgCallBack("syncIteration",

@@ -42,12 +42,12 @@ void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunica
std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1));
}

void Iteration::RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node) {
void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) {
MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node;
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning),
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
std::bind(&Iteration::HandleIterationRunningEvent, this));
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted),
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
std::bind(&Iteration::HandleIterationCompletedEvent, this));
}

@@ -56,7 +56,7 @@ void Iteration::AddRound(const std::shared_ptr<Round> &round) {
rounds_.push_back(round);
}

void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
if (communicators.empty()) {
MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";

@@ -64,7 +64,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
}

std::for_each(communicators.begin(), communicators.end(),
[&](const std::shared_ptr<core::CommunicatorBase> &communicator) {
[&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
for (auto &round : rounds_) {
if (round == nullptr) {
continue;

@@ -120,7 +120,7 @@ void Iteration::SetIterationRunning() {
}
if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationRunning));
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning));
}
iteration_state_ = IterationState::kRunning;
}

@@ -133,7 +133,7 @@ void Iteration::SetIterationCompleted() {
}
if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationCompleted));
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted));
}
iteration_state_ = IterationState::kCompleted;
}

@@ -171,7 +171,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
sync_iter_req.set_rank(rank);

std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, core::TcpUserCommand::kSyncIteration,
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration,
&sync_iter_rsp_msg)) {
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
return false;

@@ -189,7 +189,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
return true;
}

void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
MS_LOG(ERROR) << "Message is nullptr.";
return;

@@ -224,14 +224,14 @@ bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const s
notify_leader_to_next_iter_req.set_iter_num(iteration_num_);
notify_leader_to_next_iter_req.set_reason(reason);
if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank,
core::TcpUserCommand::kNotifyLeaderToNextIter)) {
ps::core::TcpUserCommand::kNotifyLeaderToNextIter)) {
MS_LOG(WARNING) << "Sending notify leader server to proceed next iteration request to leader server 0 failed.";
return false;
}
return true;
}

void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

@@ -278,7 +278,7 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
std::vector<uint32_t> offline_servers = {};
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(prepare_next_iter_req, i, core::TcpUserCommand::kPrepareForNextIter)) {
if (!communicator_->SendPbRequest(prepare_next_iter_req, i, ps::core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Sending prepare for next iteration request to server " << i << " failed. Retry later.";
offline_servers.push_back(i);
continue;

@@ -289,17 +289,18 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
// Should avoid endless loop if the server communicator is stopped.
while (communicator_->running() &&
!communicator_->SendPbRequest(prepare_next_iter_req, rank, core::TcpUserCommand::kPrepareForNextIter)) {
!communicator_->SendPbRequest(prepare_next_iter_req, rank, ps::core::TcpUserCommand::kPrepareForNextIter)) {
MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank
<< " failed. The server has not recovered yet.";
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));
}
MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success.";
});
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
return true;
}

void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

@@ -329,7 +330,7 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
proceed_to_next_iter_req.set_last_iter_num(iteration_num_);
proceed_to_next_iter_req.set_reason(reason);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, core::TcpUserCommand::kProceedToNextIter)) {
if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, ps::core::TcpUserCommand::kProceedToNextIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
continue;
}

@@ -339,7 +340,7 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
return true;
}

void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

@@ -388,7 +389,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
EndLastIterRequest end_last_iter_req;
end_last_iter_req.set_last_iter_num(last_iter_num);
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, core::TcpUserCommand::kEndLastIter)) {
if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
continue;
}

@@ -398,7 +399,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
return true;
}

void Iteration::HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (message == nullptr) {
return;
}

@@ -429,9 +430,9 @@ void Iteration::EndLastIter() {
MS_LOG(INFO) << "End the last iteration " << iteration_num_;
iteration_num_++;
// After the job is done, reset the iteration to the initial number and reset ModelStore.
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
if (iteration_num_ > ps::PSContext::instance()->fl_iteration_num()) {
MS_LOG(INFO) << "Iteration loop " << iteration_loop_count_
<< " is completed. Iteration number: " << PSContext::instance()->fl_iteration_num();
<< " is completed. Iteration number: " << ps::PSContext::instance()->fl_iteration_num();
iteration_num_ = 1;
iteration_loop_count_++;
ModelStore::GetInstance().Reset();

@@ -444,5 +445,5 @@ void Iteration::EndLastIter() {
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
#define MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_

#include <memory>
#include <vector>

@@ -26,7 +26,7 @@
#include "fl/server/local_meta_store.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
enum class IterationState {
// This iteration is still in process.

@@ -48,16 +48,16 @@ class Iteration {
}

// Register callbacks for other servers to synchronize iteration information from leader server.
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);

// Register event callbacks for iteration state synchronization.
void RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node);
void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node);

// Add a round for the iteration. This method will be called multiple times for each round.
void AddRound(const std::shared_ptr<Round> &round);

// Initialize all the rounds in the iteration.
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
void InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);

// This method will control servers to proceed to next iteration.

@@ -104,7 +104,7 @@ class Iteration {
// Synchronize iteration form the leader server(Rank 0).
bool SyncIteration(uint32_t rank);
void HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message);

// The request for moving to next iteration is not reentrant.
bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num);

@@ -112,28 +112,28 @@ class Iteration {
// The methods for moving to next iteration for all the servers.
// Step 1: follower servers notify leader server that they need to move to next iteration.
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);

// Step 2: leader server broadcast to all follower servers to prepare for next iteration and switch to safemode.
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The server prepare for the next iteration. This method will switch the server to safemode.
void PrepareForNextIter();

// Step 3: leader server broadcast to all follower servers to move to next iteration.
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Move to next iteration. Store last iterations model and reset all the rounds.
void Next(bool is_iteration_valid, const std::string &reason);

// Step 4: leader server broadcasts to all follower servers to end last iteration and cancel the safemode.
bool BroadcastEndLastIterRequest(uint64_t iteration_num);
void HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message);
void HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The server end the last iteration. This method will increase the iteration number and cancel the safemode.
void EndLastIter();

std::shared_ptr<core::ServerNode> server_node_;
std::shared_ptr<core::TcpCommunicator> communicator_;
std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<ps::core::TcpCommunicator> communicator_;

// All the rounds in the server.
std::vector<std::shared_ptr<Round>> rounds_;

@@ -155,6 +155,6 @@ class Iteration {
std::mutex pinned_mtx_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
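The Step 1-4 comments above define a leader-driven handshake for crossing an iteration boundary. A condensed sketch of the order in which the leader would drive its own declared methods (illustrative only and not compilable on its own; the real control flow also handles the Step 1 follower notification, reentrancy checks, and offline-server retries shown in iteration.cc):

// Leader-side ordering of the Step 2-4 methods declared above (sketch).
void MoveToNextIterationAsLeader(bool is_last_iter_valid, const std::string &reason,
                                 uint64_t iteration_num) {
  // Step 2: all followers prepare for next iteration and enter safemode.
  BroadcastPrepareForNextIterRequest(is_last_iter_valid, reason);
  PrepareForNextIter();
  // Step 3: all servers move; last iteration's model is stored, rounds reset.
  BroadcastMoveToNextIterRequest(is_last_iter_valid, reason);
  Next(is_last_iter_valid, reason);
  // Step 4: end the last iteration everywhere and cancel safemode.
  BroadcastEndLastIterRequest(iteration_num);
  EndLastIter();
}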
@@ -17,7 +17,7 @@
#include "fl/server/iteration_timer.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void IterationTimer::Start(const std::chrono::milliseconds &duration) {
if (running_.load()) {

@@ -46,11 +46,11 @@ void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) {
return;
}

bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) {
bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) const {
return timestamp > end_time_ ? true : false;
}

bool IterationTimer::IsRunning() { return running_; }
bool IterationTimer::IsRunning() const { return running_; }
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_
#define MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_

#include <chrono>
#include <atomic>

@@ -24,7 +24,7 @@
#include "fl/server/common.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// IterationTimer controls the time window for the purpose of eliminating trailing time of each iteration.
class IterationTimer {

@@ -42,10 +42,10 @@ class IterationTimer {
void SetTimeOutCallBack(const TimeOutCb &timeout_cb);

// Judge whether current timestamp is out of time window's range since the Start function is called.
bool IsTimeOut(const std::chrono::milliseconds &timestamp);
bool IsTimeOut(const std::chrono::milliseconds &timestamp) const;

// Judge whether the timer is keeping timing.
bool IsRunning();
bool IsRunning() const;

private:
// The running state for the timer.

@@ -59,6 +59,6 @@ class IterationTimer {
TimeOutCb timeout_callback_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_

#include <memory>
#include <string>

@@ -26,7 +26,7 @@
#include "fl/server/kernel/params_info.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// AggregationKernel is the kernel for weight, grad or other kinds of parameters' aggregation.

@@ -99,6 +99,6 @@ class AggregationKernel : public CPUKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
@@ -18,7 +18,7 @@
#include <utility>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {

@@ -67,5 +67,5 @@ bool AggregationKernelFactory::Matched(const ParamsInfo &params_info, const CNod
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_

#include <memory>
#include <string>

@@ -24,7 +24,7 @@
#include "fl/server/kernel/aggregation_kernel.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>;

@@ -51,6 +51,7 @@ class AggregationKernelRegister {
AggregationKernelCreator &&creator) {
AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
}
~AggregationKernelRegister() = default;
};

// Register aggregation kernel with one template type T.

@@ -66,6 +67,6 @@ class AggregationKernelRegister {
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); });
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
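The REG_* macros seen in the kernel .cc files expand to a namespace-scope AggregationKernelRegister (or OptimizerKernelRegister) object whose constructor registers a creator with the factory singleton during static initialization. A self-contained model of that idiom with simplified, made-up names:

// Register-at-static-init idiom; illustrative only, names simplified.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Kernel { virtual ~Kernel() = default; };
using Creator = std::function<std::shared_ptr<Kernel>()>;

struct Factory {
  static Factory &GetInstance() { static Factory f; return f; }
  void Register(const std::string &name, Creator &&creator) { creators_[name] = std::move(creator); }
  std::shared_ptr<Kernel> Create(const std::string &name) { return creators_.at(name)(); }
  std::map<std::string, Creator> creators_;
};

// The constructor runs at static-initialization time, which is what lets a
// REG_* macro at namespace scope register a kernel before main() starts.
struct Register {
  Register(const std::string &name, Creator &&creator) {
    Factory::GetInstance().Register(name, std::move(creator));
  }
};

struct FedAvgLike : Kernel {};
static Register g_fed_avg_reg("FedAvg", [] { return std::make_shared<FedAvgLike>(); });

int main() { std::cout << (Factory::GetInstance().Create("FedAvg") != nullptr) << "\n"; }  // prints 1

The added ~AggregationKernelRegister() = default; in the diff gives the register object an explicitly defaulted destructor, matching the coding convention without changing behavior.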
@@ -17,7 +17,7 @@
#include "fl/server/kernel/apply_momentum_kernel.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
REG_OPTIMIZER_KERNEL(ApplyMomentum,

@@ -30,5 +30,5 @@ REG_OPTIMIZER_KERNEL(ApplyMomentum,
ApplyMomentumKernel, float)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_

#include <vector>
#include <memory>

@@ -25,7 +25,7 @@
#include "fl/server/kernel/optimizer_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using mindspore::kernel::ApplyMomentumCPUKernel;

@@ -57,6 +57,6 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
@@ -17,7 +17,7 @@
#include "fl/server/kernel/dense_grad_accum_kernel.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
REG_AGGREGATION_KERNEL(

@@ -26,5 +26,5 @@ REG_AGGREGATION_KERNEL(
DenseGradAccumKernel, float)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_

#include <memory>
#include <string>

@@ -26,7 +26,7 @@
#include "fl/server/kernel/aggregation_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
template <typename T>

@@ -90,6 +90,6 @@ class DenseGradAccumKernel : public AggregationKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
@@ -17,7 +17,7 @@
#include "fl/server/kernel/fed_avg_kernel.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
REG_AGGREGATION_KERNEL_TWO(FedAvg,

@@ -29,5 +29,5 @@ REG_AGGREGATION_KERNEL_TWO(FedAvg,
FedAvgKernel, float, size_t)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_

#include <memory>
#include <string>

@@ -31,7 +31,7 @@
#include "fl/server/kernel/aggregation_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// The implementation for the federated average. We do weighted average for the weights. The uploaded weights from

@@ -42,7 +42,13 @@ namespace kernel {
template <typename T, typename S>
class FedAvgKernel : public AggregationKernel {
public:
FedAvgKernel() : participated_(false) {}
FedAvgKernel()
: cnode_weight_idx_(0),
weight_addr_(nullptr),
data_size_addr_(nullptr),
new_weight_addr_(nullptr),
new_data_size_addr_(nullptr),
participated_(false) {}
~FedAvgKernel() override = default;

void InitKernel(const CNodePtr &kernel_node) override {

@@ -68,13 +74,13 @@ class FedAvgKernel : public AggregationKernel {
AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
MS_EXCEPTION_IF_NULL(weight_node);
name_ = cnode_name + "." + weight_node->fullname_with_scope();
first_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) {
first_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
std::unique_lock<std::mutex> lock(weight_mutex_);
if (!participated_) {
ClearWeightAndDataSize();
}
};
last_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) {
last_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
size_t weight_size = weight_addr_->size;
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);

@@ -193,7 +199,7 @@ class FedAvgKernel : public AggregationKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_
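The comment above describes FedAvg as a weighted average, and the kernel carries both a weight buffer and a data-size buffer (template parameters T and S, registered as float and size_t). Assuming the standard FedAvg formula, each upload is scaled by its data size and the sum is divided by the total data size. A self-contained numeric sketch (the kernel itself does this in-place on AddressPtr buffers):

// FedAvg as sum(weight_i * data_size_i) / sum(data_size_i); numeric sketch.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> FedAvg(const std::vector<std::vector<float>> &weights,
                          const std::vector<size_t> &data_sizes) {
  std::vector<float> avg(weights[0].size(), 0.0f);
  float total = 0.0f;
  for (size_t s : data_sizes) total += static_cast<float>(s);
  for (size_t i = 0; i < weights.size(); i++) {
    for (size_t j = 0; j < avg.size(); j++) {
      avg[j] += weights[i][j] * static_cast<float>(data_sizes[i]) / total;
    }
  }
  return avg;
}

int main() {
  // Two clients with 100 and 300 samples -> weights 0.25 and 0.75.
  auto avg = FedAvg({{1.0f, 2.0f}, {3.0f, 4.0f}}, {100, 300});
  std::cout << avg[0] << " " << avg[1] << "\n";  // 2.5 3.5
}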
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_

#include <memory>
#include <string>

@@ -26,7 +26,7 @@
#include "fl/server/kernel/params_info.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// KernelFactory is used to select and build kernels in server. It's the base class of OptimizerKernelFactory

@@ -87,6 +87,6 @@ class KernelFactory {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_

#include <memory>
#include <string>

@@ -28,7 +28,7 @@
#include "fl/server/kernel/params_info.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using mindspore::kernel::IsSameShape;

@@ -92,6 +92,6 @@ class OptimizerKernel : public CPUKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
@@ -18,7 +18,7 @@
#include <utility>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {

@@ -66,5 +66,5 @@ bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodeP
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_

#include <memory>
#include <string>

@@ -24,7 +24,7 @@
#include "fl/server/kernel/optimizer_kernel.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>;

@@ -50,6 +50,7 @@ class OptimizerKernelRegister {
OptimizerKernelRegister(const std::string &name, const ParamsInfo &params_info, OptimizerKernelCreator &&creator) {
OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
}
~OptimizerKernelRegister() = default;
};

// Register optimizer kernel with one template type T.

@@ -59,6 +60,6 @@ class OptimizerKernelRegister {
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
@@ -18,7 +18,7 @@
#include "utils/log_adapter.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
ParamsInfo &ParamsInfo::AddInputNameType(const std::string &name, TypeId type) {

@@ -64,5 +64,5 @@ const std::vector<std::string> &ParamsInfo::workspace_names() const { return wor
const std::vector<std::string> &ParamsInfo::outputs_names() const { return outputs_names_; }
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_

#include <utility>
#include <string>

@@ -23,7 +23,7 @@
#include "ir/dtype/type_id.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// ParamsInfo is used for server computation kernel's register, e.g, ApplyMomentumKernel, FedAvgKernel, etc.

@@ -65,6 +65,6 @@ class ParamsInfo {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_
@@ -22,7 +22,7 @@
#include "schema/cipher_generated.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void ClientListKernel::InitKernel(size_t) {

@@ -150,7 +150,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
return true;
} // namespace ps
} // namespace fl

bool ClientListKernel::Reset() {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();

@@ -196,5 +196,5 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli
REG_ROUND_KERNEL(getClientList, ClientListKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#include <string>
#include <vector>
#include <memory>

@@ -26,7 +26,7 @@
#include "fl/server/executor.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class ClientListKernel : public RoundKernel {

@@ -50,6 +50,6 @@ class ClientListKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
@@ -20,7 +20,7 @@
#include <memory>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void ExchangeKeysKernel::InitKernel(size_t) {

@@ -100,5 +100,5 @@ bool ExchangeKeysKernel::Reset() {
REG_ROUND_KERNEL(exchangeKeys, ExchangeKeysKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H

#include <vector>
#include "fl/server/common.h"

@@ -25,7 +25,7 @@
#include "fl/armour/cipher/cipher_keys.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class ExchangeKeysKernel : public RoundKernel {

@@ -44,7 +44,7 @@ class ExchangeKeysKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
@@ -19,7 +19,7 @@
#include <memory>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void GetKeysKernel::InitKernel(size_t) {

@@ -99,5 +99,5 @@ bool GetKeysKernel::Reset() {
REG_ROUND_KERNEL(getKeys, GetKeysKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H

#include <vector>
#include "fl/server/common.h"

@@ -25,7 +25,7 @@
#include "fl/armour/cipher/cipher_keys.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class GetKeysKernel : public RoundKernel {

@@ -44,7 +44,7 @@ class GetKeysKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H

@@ -23,7 +23,7 @@
#include "fl/server/model_store.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void GetModelKernel::InitKernel(size_t) {

@@ -39,9 +39,8 @@ void GetModelKernel::InitKernel(size_t) {
}
}

bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching GetModelKernel kernel.";
void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {

@@ -49,6 +48,11 @@ bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
return false;
}

++retry_count_;
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
MS_LOG(INFO) << "Launching GetModelKernel retry count is " << retry_count_.load();
}

const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
GetModel(get_model_req, fbb);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());

@@ -58,6 +62,7 @@ bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
bool GetModelKernel::Reset() {
MS_LOG(INFO) << "Get model kernel reset!";
StopTimer();
retry_count_ = 0;
return true;
}

@@ -76,7 +81,9 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
"2. Worker has not push all the weights to servers.";
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
std::to_string(next_req_time));
MS_LOG(WARNING) << reason;
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
MS_LOG(WARNING) << reason;
}
return;
}

@@ -126,5 +133,5 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con
REG_ROUND_KERNEL(getModel, GetModelKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

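Besides the namespace rename, this file gains retry accounting: an atomic counter is bumped on every getModel request and the WARNING for "model not ready" is printed only on the 1st, 51st, 101st, ... attempt, so a worker polling for an unfinished iteration no longer floods the server log. A self-contained sketch of the same throttling idea (class and constant names here are illustrative, not the kernel itself):

    #include <atomic>
    #include <cstdint>
    #include <iostream>

    // Log only one message per kLogEvery retries instead of one per retry.
    constexpr uint64_t kLogEvery = 50;

    class RetryThrottle {
     public:
      void OnRetry(const char *reason) {
        // fetch_add returns the previous value; +1 gives the current retry number.
        uint64_t count = retry_count_.fetch_add(1) + 1;
        if (count % kLogEvery == 1) {
          std::cout << reason << " Retry count is " << count << "\n";
        }
      }
      void Reset() { retry_count_ = 0; }  // cleared when the round resets

     private:
      std::atomic<uint64_t> retry_count_{0};
    };
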
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_MODEL_KERNEL_H_

#include <map>
#include <memory>

@@ -27,12 +27,13 @@
#include "fl/server/kernel/round/round_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
constexpr uint32_t kPrintGetModelForEveryRetryTime = 50;
class GetModelKernel : public RoundKernel {
public:
GetModelKernel() = default;
GetModelKernel() : executor_(nullptr), iteration_time_window_(0), retry_count_(0) {}
~GetModelKernel() override = default;

void InitKernel(size_t) override;

@@ -51,9 +52,12 @@ class GetModelKernel : public RoundKernel {

// The time window of one iteration.
size_t iteration_time_window_;

// The count of retrying because the iteration is not finished.
std::atomic<uint64_t> retry_count_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

@@ -21,7 +21,7 @@
#include "fl/armour/cipher/cipher_shares.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void GetSecretsKernel::InitKernel(size_t) {

@@ -102,5 +102,5 @@ bool GetSecretsKernel::Reset() {
REG_ROUND_KERNEL(getSecrets, GetSecretsKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H

#include <vector>
#include "fl/server/common.h"

@@ -25,7 +25,7 @@
#include "fl/server/executor.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class GetSecretsKernel : public RoundKernel {

@@ -44,7 +44,7 @@ class GetSecretsKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H

@@ -22,10 +22,9 @@
#include "fl/server/model_store.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
uint64_t PullWeightKernel::retry_count_ = 0;
void PullWeightKernel::InitKernel(size_t) {
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);

@@ -35,7 +34,7 @@ void PullWeightKernel::InitKernel(size_t) {
}
}

bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
MS_LOG(DEBUG) << "Launching PullWeightKernel kernel.";
void *req_data = inputs[0]->addr;

@@ -59,13 +58,16 @@ bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return true;
}

bool PullWeightKernel::Reset() { return true; }
bool PullWeightKernel::Reset() {
retry_count_ = 0;
return true;
}

void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPullWeight *pull_weight_req) {
std::map<std::string, AddressPtr> feature_maps = {};
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t pull_weight_iter = static_cast<size_t>(pull_weight_req->iteration());
// The PullWeight round should be in the same iteration as other rounds.
// The iteration from worker should be the same as server's, otherwise return SucNotReady so that worker could retry.
if (pull_weight_iter != current_iter) {
std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) +
" is invalid. Server current iteration: " + std::to_string(current_iter);

@@ -76,15 +78,19 @@ void PullWeightKernel::PullWeight(std::shared_ptr<FBBuilder> fbb, const schema::

std::vector<std::string> weight_names = {};
auto weights_names_fbs = pull_weight_req->weight_names();
if (weights_names_fbs == nullptr) {
MS_LOG(ERROR) << "weights_names_fbs is nullptr.";
return;
}
for (size_t i = 0; i < weights_names_fbs->size(); i++) {
weight_names.push_back(weights_names_fbs->Get(i)->str());
}
if (!executor_->IsWeightAggrDone(weight_names)) {
retry_count_++;
++retry_count_;
std::string reason = "The aggregation for the weights is not done yet.";
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps);
if (retry_count_ % kPrintPullWeightForEveryRetryTime == 1) {
MS_LOG(WARNING) << reason << " Retry count is " << retry_count_;
if (retry_count_.load() % kPrintPullWeightForEveryRetryTime == 1) {
MS_LOG(WARNING) << reason << " Retry count is " << retry_count_.load();
}
return;
}

@@ -131,5 +137,5 @@ void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const
REG_ROUND_KERNEL(pullWeight, PullWeightKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_

#include <map>
#include <memory>

@@ -27,13 +27,13 @@
#include "fl/server/executor.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
constexpr uint32_t kPrintPullWeightForEveryRetryTime = 500;
class PullWeightKernel : public RoundKernel {
public:
PullWeightKernel() = default;
PullWeightKernel() : executor_(nullptr), retry_count_(0) {}
~PullWeightKernel() override = default;

void InitKernel(size_t required_cnt) override;

@@ -49,10 +49,10 @@ class PullWeightKernel : public RoundKernel {
Executor *executor_;

// The count of retrying because the aggregation of the weights is not done.
static uint64_t retry_count_;
std::atomic<uint64_t> retry_count_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_

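Here the `static uint64_t retry_count_` class variable is replaced by a per-instance `std::atomic<uint64_t>`, and Reset() now clears it at the end of each iteration. The distinction matters when Launch runs on multiple handler threads; a compressed sketch of the two variants (hypothetical names, not the kernel itself):

    #include <atomic>
    #include <cstdint>

    class PullCounterOld {
      static uint64_t retry_count_;  // shared by all instances; unsynchronized ++ races
    };

    class PullCounterNew {
     public:
      void Bump() { ++retry_count_; }     // atomic increment, safe across threads
      void Reset() { retry_count_ = 0; }  // cleared when the round resets
      uint64_t Value() const { return retry_count_.load(); }

     private:
      std::atomic<uint64_t> retry_count_{0};  // one counter per kernel instance
    };
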
@@ -17,7 +17,7 @@
#include "fl/server/kernel/round/push_weight_kernel.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void PushWeightKernel::InitKernel(size_t) {

@@ -30,7 +30,7 @@ void PushWeightKernel::InitKernel(size_t) {
local_rank_ = DistributedCountService::GetInstance().local_rank();
}

bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
void *req_data = inputs[0]->addr;

@@ -60,8 +60,8 @@ bool PushWeightKernel::Reset() {
return true;
}

void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) {
if (PSContext::instance()->resetter_round() == ResetterRound::kPushWeight) {
void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushWeight) {
FinishIteration();
}
return;

@@ -111,6 +111,7 @@ std::map<std::string, Address> PushWeightKernel::ParseFeatureMap(const schema::R
RETURN_IF_NULL(push_weight_req, {});
std::map<std::string, Address> upload_feature_map;
auto fbs_feature_map = push_weight_req->feature_map();
RETURN_IF_NULL(push_weight_req, upload_feature_map);
for (size_t i = 0; i < fbs_feature_map->size(); i++) {
std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());

@@ -135,5 +136,5 @@ void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const
REG_ROUND_KERNEL(pushWeight, PushWeightKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_

#include <map>
#include <memory>

@@ -27,19 +27,19 @@
#include "fl/server/executor.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class PushWeightKernel : public RoundKernel {
public:
PushWeightKernel() = default;
PushWeightKernel() : executor_(nullptr), local_rank_(0) {}
~PushWeightKernel() override = default;

void InitKernel(size_t threshold_count) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

private:
bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);

@@ -52,6 +52,6 @@ class PushWeightKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_

@@ -20,7 +20,7 @@
#include <memory>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {

@@ -34,17 +34,17 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
MS_LOG(INFO) << "start FinishIteration";
FinishIteration();
MS_LOG(INFO) << "end FinishIteration";
return;
};
auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) { return; };
auto first_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) { return; };
name_unmask_ = "UnMaskKernel";
MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : "
<< LocalMetaStore::GetInstance().curr_iter_num();
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, PSContext::instance()->initial_server_num(),
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, ps::PSContext::instance()->initial_server_num(),
{first_cnt_handler, last_cnt_handler});
}

@@ -134,9 +134,9 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
return true;
}

void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
if (PSContext::instance()->encrypt_type() == kPWEncryptType) {
if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) {
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}

@@ -164,5 +164,5 @@ bool ReconstructSecretsKernel::Reset() {
REG_ROUND_KERNEL(reconstructSecrets, ReconstructSecretsKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_

#include <vector>
#include <memory>

@@ -27,7 +27,7 @@
#include "fl/server/executor.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class ReconstructSecretsKernel : public RoundKernel {

@@ -39,7 +39,7 @@ class ReconstructSecretsKernel : public RoundKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

private:
std::string name_unmask_;

@@ -49,6 +49,6 @@ class ReconstructSecretsKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_

@@ -24,7 +24,7 @@
#include <vector>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) {

@@ -61,25 +61,25 @@ RoundKernel::~RoundKernel() {
}
}

void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; }
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }

void RoundKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; }
void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }

void RoundKernel::StopTimer() {
void RoundKernel::StopTimer() const {
if (stop_timer_cb_) {
stop_timer_cb_();
}
return;
}

void RoundKernel::FinishIteration() {
void RoundKernel::FinishIteration() const {
if (finish_iteration_cb_) {
finish_iteration_cb_(true, "");
}
return;
}

void RoundKernel::Release(AddressPtr addr_ptr) {
void RoundKernel::Release(const AddressPtr &addr_ptr) {
if (addr_ptr == nullptr) {
MS_LOG(ERROR) << "Data to be released is empty.";
return;

@@ -91,13 +91,13 @@ void RoundKernel::Release(AddressPtr addr_ptr) {

void RoundKernel::set_name(const std::string &name) { name_ = name; }

void RoundKernel::set_stop_timer_cb(StopTimerCb timer_stopper) { stop_timer_cb_ = timer_stopper; }
void RoundKernel::set_stop_timer_cb(const StopTimerCb &timer_stopper) { stop_timer_cb_ = timer_stopper; }

void RoundKernel::set_finish_iteration_cb(FinishIterCb finish_iteration_cb) {
void RoundKernel::set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb) {
finish_iteration_cb_ = finish_iteration_cb;
}

void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len) {
void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const void *data, size_t len) {
if (data == nullptr) {
MS_LOG(ERROR) << "The data is nullptr.";
return;

@@ -129,5 +129,5 @@ void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, void *d
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_

#include <map>
#include <memory>

@@ -35,7 +35,7 @@
#include "fl/server/distributed_metadata_store.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round

@@ -67,31 +67,31 @@ class RoundKernel : virtual public CPUKernel {
// The counter event handlers for DistributedCountService.
// The callbacks when first message and last message for this round kernel is received.
// These methods is called by class DistributedCountService and triggered by counting server.
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
virtual void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
virtual void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);

// Called when this round is finished. This round timer's Stop method will be called.
void StopTimer();
void StopTimer() const;

// Called after this iteration(including all rounds) is finished. All rounds' Reset method will
// be called.
void FinishIteration();
void FinishIteration() const;

// Release the response data allocated inside the round kernel.
// Server framework must call this after the response data is sent back.
void Release(AddressPtr addr_ptr);
void Release(const AddressPtr &addr_ptr);

// Set round kernel name, which could be used in round kernel's methods.
void set_name(const std::string &name);

// Set callbacks to be called under certain triggered conditions.
void set_stop_timer_cb(StopTimerCb timer_stopper);
void set_finish_iteration_cb(FinishIterCb finish_iteration_cb);
void set_stop_timer_cb(const StopTimerCb &timer_stopper);
void set_finish_iteration_cb(const FinishIterCb &finish_iteration_cb);

protected:
// Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent
// back to worker.
void GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len);
void GenerateOutput(const std::vector<AddressPtr> &outputs, const void *data, size_t len);

// Round kernel's name.
std::string name_;

@@ -123,6 +123,6 @@ class RoundKernel : virtual public CPUKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_

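Alongside the namespace rename, RoundKernel's interface is tightened for const-correctness: StopTimer() and FinishIteration() become const, and Release() plus the two callback setters take const references instead of by-value copies of std::function objects. A small sketch of the resulting pattern (assuming std::function-style callback aliases, as the header suggests):

    #include <functional>
    #include <string>

    using StopTimerCb = std::function<void(void)>;
    using FinishIterCb = std::function<void(bool, const std::string &)>;

    class KernelBase {
     public:
      // Taking const& avoids copying the std::function at the call site;
      // the member still stores its own copy via assignment.
      void set_stop_timer_cb(const StopTimerCb &cb) { stop_timer_cb_ = cb; }
      void set_finish_iteration_cb(const FinishIterCb &cb) { finish_iteration_cb_ = cb; }

      // const: invoking a stored callback does not mutate the kernel itself.
      void StopTimer() const {
        if (stop_timer_cb_) {
          stop_timer_cb_();
        }
      }
      void FinishIteration() const {
        if (finish_iteration_cb_) {
          finish_iteration_cb_(true, "");
        }
      }

     private:
      StopTimerCb stop_timer_cb_;
      FinishIterCb finish_iteration_cb_;
    };
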
@@ -17,7 +17,7 @@
#include "fl/server/kernel/round/round_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
RoundKernelFactory &RoundKernelFactory::GetInstance() {

@@ -40,5 +40,5 @@ std::shared_ptr<RoundKernel> RoundKernelFactory::Create(const std::string &name)
}
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_

#include <memory>
#include <string>

@@ -25,7 +25,7 @@
#include "fl/server/kernel/round/round_kernel.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>;

@@ -50,6 +50,7 @@ class RoundKernelRegister {
RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) {
RoundKernelFactory::GetInstance().Register(name, std::move(creator));
}
~RoundKernelRegister() = default;
};

#define REG_ROUND_KERNEL(NAME, CLASS) \

@@ -57,6 +58,6 @@ class RoundKernelRegister {
static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); });
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_

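The hunk above also gives RoundKernelRegister an explicit defaulted destructor. The REG_ROUND_KERNEL macro used throughout these files is the usual static-registration idiom: a file-scope register object runs before main() and hands a creator lambda to the factory. A standalone sketch of the idiom (simplified types, not the MindSpore classes):

    #include <functional>
    #include <map>
    #include <memory>
    #include <string>

    struct Kernel { virtual ~Kernel() = default; };

    using Creator = std::function<std::shared_ptr<Kernel>()>;

    class Factory {
     public:
      static Factory &GetInstance() {
        static Factory inst;  // Meyers singleton
        return inst;
      }
      void Register(const std::string &name, Creator &&creator) {
        creators_[name] = std::move(creator);
      }
      std::shared_ptr<Kernel> Create(const std::string &name) {
        auto it = creators_.find(name);
        return it == creators_.end() ? nullptr : it->second();
      }

     private:
      std::map<std::string, Creator> creators_;
    };

    // A file-scope instance of this struct registers the kernel at startup.
    struct Register {
      Register(const std::string &name, Creator &&creator) {
        Factory::GetInstance().Register(name, std::move(creator));
      }
      ~Register() = default;
    };

    #define REG_KERNEL(NAME, CLASS) \
      static const Register g_##NAME##_reg(#NAME, [] { return std::make_shared<CLASS>(); });
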
@@ -19,7 +19,7 @@
#include <memory>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void ShareSecretsKernel::InitKernel(size_t) {

@@ -101,5 +101,5 @@ bool ShareSecretsKernel::Reset() {
REG_ROUND_KERNEL(shareSecrets, ShareSecretsKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H

#include <vector>
#include "fl/server/common.h"

@@ -25,7 +25,7 @@
#include "fl/armour/cipher/cipher_shares.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class ShareSecretsKernel : public RoundKernel {

@@ -44,7 +44,7 @@ class ShareSecretsKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H

@@ -26,7 +26,7 @@
#endif

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void StartFLJobKernel::InitKernel(size_t) {

@@ -34,7 +34,7 @@ void StartFLJobKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);

executor_ = &Executor::GetInstance();

@@ -49,7 +49,7 @@ void StartFLJobKernel::InitKernel(size_t) {
return;
}

bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching StartFLJobKernel kernel.";
if (inputs.size() != 1 || outputs.size() != 1) {

@@ -113,7 +113,7 @@ bool StartFLJobKernel::Reset() {
return true;
}

void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
// The first startFLJob request means a new iteration starts running.

@@ -133,6 +133,7 @@ bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuild
}

DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) {
RETURN_IF_NULL(start_fl_job_req, {});
std::string fl_name = start_fl_job_req->fl_name()->str();
std::string fl_id = start_fl_job_req->fl_id()->str();
int data_size = start_fl_job_req->data_size();

@@ -141,7 +142,7 @@ DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *st
DeviceMeta device_meta;
device_meta.set_fl_name(fl_name);
device_meta.set_fl_id(fl_id);
device_meta.set_data_size(data_size);
device_meta.set_data_size(IntToSize(data_size));
return device_meta;
}

@@ -154,7 +155,7 @@ bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
}
if (!ret) {
BuildStartFLJobRsp(
fbb, schema::ResponseCode_NotSelected, reason, false,
fbb, schema::ResponseCode_RequestError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
}

@@ -163,6 +164,7 @@ bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,

bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
RETURN_IF_NULL(start_fl_job_req, false);
if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) {
std::string reason = "Counting start fl job request failed. Please retry later.";
BuildStartFLJobRsp(

@@ -192,8 +194,8 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
std::map<std::string, AddressPtr> feature_maps) {
auto fbs_reason = fbb->CreateString(reason);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode());
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
auto fbs_server_mode = fbb->CreateString(ps::PSContext::instance()->server_mode());
auto fbs_fl_name = fbb->CreateString(ps::PSContext::instance()->fl_name());

#ifdef ENABLE_ARMOUR
auto *param = armour::CipherInit::GetInstance().GetPublicParams();

@@ -204,7 +206,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
float dp_eps = param->dp_eps;
float dp_delta = param->dp_delta;
float dp_norm_clip = param->dp_norm_clip;
auto encrypt_type = fbb->CreateString(PSContext::instance()->encrypt_type());
auto encrypt_type = fbb->CreateString(ps::PSContext::instance()->encrypt_type());

auto cipher_public_params =
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);

@@ -213,10 +215,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
fl_plan_builder.add_fl_name(fbs_fl_name);
fl_plan_builder.add_server_mode(fbs_server_mode);
fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate());
fl_plan_builder.add_iterations(ps::PSContext::instance()->fl_iteration_num());
fl_plan_builder.add_epochs(ps::PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(ps::PSContext::instance()->client_batch_size());
fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate());
#ifdef ENABLE_ARMOUR
fl_plan_builder.add_cipher(cipher_public_params);
#endif

@@ -235,7 +237,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get()));
rsp_fl_job_builder.add_retcode(retcode);
rsp_fl_job_builder.add_reason(fbs_reason);
rsp_fl_job_builder.add_iteration(LocalMetaStore::GetInstance().curr_iter_num());
rsp_fl_job_builder.add_iteration(SizeToInt(LocalMetaStore::GetInstance().curr_iter_num()));
rsp_fl_job_builder.add_is_selected(is_selected);
rsp_fl_job_builder.add_next_req_time(fbs_next_req_time);
rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan);

@@ -248,5 +250,5 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
REG_ROUND_KERNEL(startFLJob, StartFLJobKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

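Several hunks in this file add RETURN_IF_NULL(start_fl_job_req, ...) guards before the flatbuffer request's fields are dereferenced: a deserialized message root or any of its field accessors can be null, so each pointer must be checked before calling ->str() or ->size() on it. A sketch of the guard pattern with the macro spelled out (types and field names here are illustrative stand-ins for the generated schema classes):

    #include <string>

    // Expanded form of a RETURN_IF_NULL(ptr, ret) style guard macro.
    #define RETURN_IF_NULL(ptr, ret) \
      if ((ptr) == nullptr) {        \
        return (ret);                \
      }

    struct FbsString {
      std::string str() const { return value; }
      std::string value;
    };
    struct Request {
      const FbsString *fl_name() const { return name; }
      const FbsString *name = nullptr;
    };

    std::string ReadFlName(const Request *req) {
      RETURN_IF_NULL(req, "");             // the request root may be null
      RETURN_IF_NULL(req->fl_name(), "");  // each field accessor may also be null
      return req->fl_name()->str();
    }
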
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_

#include <map>
#include <memory>

@@ -27,7 +27,7 @@
#include "fl/server/kernel/round/round_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
class StartFLJobKernel : public RoundKernel {

@@ -40,7 +40,7 @@ class StartFLJobKernel : public RoundKernel {
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;

void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

private:
// Returns whether the startFLJob count of this iteration has reached the threshold.

@@ -74,6 +74,6 @@ class StartFLJobKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_

@@ -21,7 +21,7 @@
#include "fl/server/kernel/round/update_model_kernel.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
void UpdateModelKernel::InitKernel(size_t threshold_count) {

@@ -42,7 +42,7 @@ void UpdateModelKernel::InitKernel(size_t threshold_count) {
LocalMetaStore::GetInstance().put_value(kCtxFedAvgTotalDataSize, kInitialDataSizeSum);
}

bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "inputs or outputs size is invalid.";

@@ -87,8 +87,8 @@ bool UpdateModelKernel::Reset() {
return true;
}

void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
if (PSContext::instance()->resetter_round() == ResetterRound::kUpdateModel) {
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) {
while (!executor_->IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}

@@ -96,7 +96,7 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
<< total_data_size;
if (PSContext::instance()->encrypt_type() != kPWEncryptType) {
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
FinishIteration();
}
}

@@ -116,6 +116,7 @@ bool UpdateModelKernel::ReachThresholdForUpdateModel(const std::shared_ptr<FBBui

bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_model_req,
const std::shared_ptr<FBBuilder> &fbb) {
RETURN_IF_NULL(update_model_req, false);
size_t iteration = static_cast<size_t>(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +

@@ -131,6 +132,7 @@ bool UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *update_mod
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
std::string update_model_fl_id = update_model_req->fl_id()->str();
MS_LOG(INFO) << "Update model for fl id " << update_model_fl_id;
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
BuildUpdateModelRsp(

@@ -180,6 +182,7 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(
RETURN_IF_NULL(update_model_req, {});
std::map<std::string, UploadData> feature_map;
auto fbs_feature_map = update_model_req->feature_map();
RETURN_IF_NULL(fbs_feature_map, feature_map);
for (size_t i = 0; i < fbs_feature_map->size(); i++) {
std::string weight_full_name = fbs_feature_map->Get(i)->weight_fullname()->str();
float *weight_data = const_cast<float *>(fbs_feature_map->Get(i)->data()->data());

@@ -194,6 +197,7 @@ std::map<std::string, UploadData> UpdateModelKernel::ParseFeatureMap(

bool UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req) {
RETURN_IF_NULL(update_model_req, false);
if (!DistributedCountService::GetInstance().Count(name_, update_model_req->fl_id()->str())) {
std::string reason = "Counting for update model request failed. Please retry later.";
BuildUpdateModelRsp(

@@ -222,5 +226,5 @@ void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fb
REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

#include <map>
#include <memory>

@@ -27,7 +27,7 @@
#include "fl/server/executor.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
namespace kernel {
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.

@@ -35,7 +35,7 @@ constexpr uint64_t kInitialDataSizeSum = 0;

class UpdateModelKernel : public RoundKernel {
public:
UpdateModelKernel() = default;
UpdateModelKernel() : executor_(nullptr), iteration_time_window_(0) {}
~UpdateModelKernel() override = default;

void InitKernel(size_t threshold_count) override;

@@ -44,7 +44,7 @@ class UpdateModelKernel : public RoundKernel {
bool Reset() override;

// In some cases, the last updateModel message means this server iteration is finished.
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;

private:
bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);

@@ -62,6 +62,6 @@ class UpdateModelKernel : public RoundKernel {
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_

@@ -17,7 +17,7 @@
#include "fl/server/local_meta_store.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void LocalMetaStore::remove_value(const std::string &name) {
std::unique_lock<std::mutex> lock(mtx_);

@@ -41,5 +41,5 @@ const size_t LocalMetaStore::curr_iter_num() {
return curr_iter_num_;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_
#define MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_

#include <any>
#include <mutex>

@@ -24,7 +24,7 @@
#include "fl/server/common.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// LocalMetaStore class is used for metadata storage of this server process.
// For example, the current iteration number, time windows for round kernels, etc.

@@ -71,7 +71,7 @@ class LocalMetaStore {
const size_t curr_iter_num();

private:
LocalMetaStore() = default;
LocalMetaStore() : key_to_meta_({}), curr_iter_num_(0) {}
~LocalMetaStore() = default;
LocalMetaStore(const LocalMetaStore &) = delete;
LocalMetaStore &operator=(const LocalMetaStore &) = delete;

@@ -83,6 +83,6 @@ class LocalMetaStore {
size_t curr_iter_num_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_

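The patched constructor now zero-initializes curr_iter_num_ instead of leaving it to a defaulted constructor. LocalMetaStore itself is a mutex-guarded, type-erased key-value store: put_value stores anything, value<T> reads it back by name. A condensed sketch of that shape using std::any (hypothetical class, mirroring the header above; note the <any> include in the hunk):

    #include <any>
    #include <cstddef>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    class MetaStore {
     public:
      template <typename T>
      void put_value(const std::string &name, const T &value) {
        std::unique_lock<std::mutex> lock(mtx_);
        key_to_meta_[name] = value;
      }

      template <typename T>
      T value(const std::string &name) {
        std::unique_lock<std::mutex> lock(mtx_);
        // std::any_cast throws std::bad_any_cast on a type mismatch.
        return std::any_cast<T>(key_to_meta_.at(name));
      }

      bool has_value(const std::string &name) {
        std::unique_lock<std::mutex> lock(mtx_);
        return key_to_meta_.count(name) != 0;
      }

     private:
      std::mutex mtx_;
      std::unordered_map<std::string, std::any> key_to_meta_;
      size_t curr_iter_num_ = 0;  // zero-initialized, as in the patched constructor
    };
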
@@ -18,7 +18,7 @@
#include <utility>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) {
addresses_.try_emplace(name, address);

@@ -32,5 +32,5 @@ void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64

void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); }
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
#define MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_
#define MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_

#include <map>
#include <string>

@@ -26,7 +26,7 @@
#include "fl/server/common.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// Memory allocated in server is normally trainable parameters, hyperparameters, gradients, etc.
// MemoryRegister registers the Memory with key-value format where key refers to address's name("grad", "weights",

@@ -88,6 +88,6 @@ class MemoryRegister {
}
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_

@@ -21,7 +21,7 @@
#include "fl/server/executor.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
void ModelStore::Initialize(uint32_t max_count) {
if (!Executor::GetInstance().initialized()) {

@@ -155,5 +155,5 @@ size_t ModelStore::ComputeModelSize() {
return model_size;
}
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
#define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_
#define MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_

#include <map>
#include <memory>

@@ -25,7 +25,7 @@
#include "fl/server/executor.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// The initial iteration number is 0 in server.
constexpr size_t kInitIterationNum = 0;

@@ -84,6 +84,6 @@ class ModelStore {
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_

@@ -23,7 +23,7 @@
#include <algorithm>

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
MS_EXCEPTION_IF_NULL(cnode);

@@ -199,8 +199,8 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
}

bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
if (PSContext::instance()->server_mode() == kServerModeFL ||
PSContext::instance()->server_mode() == kServerModeHybrid) {
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
return true;
}

@@ -319,15 +319,15 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
return true;
}

std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) {
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
std::vector<std::string> aggregation_algorithm = {};
if (PSContext::instance()->server_mode() == kServerModeFL ||
PSContext::instance()->server_mode() == kServerModeHybrid) {
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
aggregation_algorithm.push_back("FedAvg");
} else if (PSContext::instance()->server_mode() == kServerModePS) {
} else if (ps::PSContext::instance()->server_mode() == ps::kServerModePS) {
aggregation_algorithm.push_back("DenseGradAccum");
} else {
MS_LOG(ERROR) << "Server doesn't support mode " << PSContext::instance()->server_mode();
MS_LOG(ERROR) << "Server doesn't support mode " << ps::PSContext::instance()->server_mode();
}

MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;

@@ -344,5 +344,5 @@ template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::Aggregat
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
std::shared_ptr<MemoryRegister> memory_register);
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore

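SelectAggregationAlgorithm maps the configured server mode to an aggregator name: federated and hybrid modes select FedAvg, plain parameter-server mode selects dense gradient accumulation. The control flow reduces to a simple dispatch; a sketch with plain strings standing in for the PSContext getters and the kServerMode* constants (the literal values here are assumptions, not the real constants):

    #include <iostream>
    #include <string>
    #include <vector>

    // Mirrors the mode -> algorithm mapping in SelectAggregationAlgorithm.
    std::vector<std::string> SelectAggregationAlgorithm(const std::string &server_mode) {
      std::vector<std::string> aggregation_algorithm;
      if (server_mode == "FEDERATED_LEARNING" || server_mode == "HYBRID_TRAINING") {
        aggregation_algorithm.push_back("FedAvg");
      } else if (server_mode == "PARAMETER_SERVER") {
        aggregation_algorithm.push_back("DenseGradAccum");
      } else {
        std::cerr << "Server doesn't support mode " << server_mode << "\n";
      }
      return aggregation_algorithm;
    }
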
@@ -14,8 +14,8 @@
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
#define MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
#ifndef MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
#define MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_

#include <map>
#include <memory>

@@ -28,7 +28,7 @@
#include "fl/server/kernel/optimizer_kernel_factory.h"

namespace mindspore {
namespace ps {
namespace fl {
namespace server {
// Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server
// kernels.

@@ -137,6 +137,6 @@ class ParameterAggregator {
std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
};
} // namespace server
} // namespace ps
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
#endif // MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_

@@ -21,7 +21,7 @@
 #include "fl/server/iteration.h"
 
 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
 class Server;
 class Iteration;
@@ -34,14 +34,14 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo
       threshold_count_(threshold_count),
       server_num_as_threshold_(server_num_as_threshold) {}
 
-void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
+void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
                        FinishIterCb finish_iteration_cb) {
   MS_EXCEPTION_IF_NULL(communicator);
   communicator_ = communicator;
 
   // Register callback for round kernel.
   communicator_->RegisterMsgCallBack(
-    name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });
+    name_, [&](std::shared_ptr<ps::core::MessageHandler> message) { LaunchRoundKernel(message); });
 
   // Callback when the iteration is finished.
   finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
@@ -106,7 +106,7 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
   return;
 }
 
-void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) {
+void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
   if (message == nullptr) {
     MS_LOG(ERROR) << "Message is nullptr.";
     return;
@@ -152,7 +152,7 @@ bool Round::check_timeout() const { return check_timeout_; }
 
 size_t Round::time_window() const { return time_window_; }
 
-void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
+void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
   MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
   // The timer starts only after the first count event is triggered by DistributedCountService.
   if (check_timeout_) {
@@ -164,7 +164,7 @@ void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &messa
   return;
 }
 
-void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
+void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
   MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
   // Same as the first count event, the timer must be stopped by DistributedCountService.
   if (check_timeout_) {
@@ -176,5 +176,5 @@ void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &messag
   return;
 }
 }  // namespace server
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
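Reviewer note: Round itself moved into `mindspore::fl`, but the communicator types it uses stayed in `mindspore::ps::core`. That is why every bare `core::` in this file gains a `ps::` qualifier: unqualified lookup from inside `fl::server` walks outward through `fl` and `mindspore` and never finds a `core` namespace on that path. A self-contained illustration with toy types:

#include <memory>

namespace mindspore {
namespace ps {
namespace core {
struct MessageHandler {};  // stand-in for the real communicator message type
}  // namespace core
}  // namespace ps

namespace fl {
namespace server {
// Inside mindspore::fl::server, plain `core::MessageHandler` does not resolve,
// because `core` only exists under mindspore::ps. The qualified form is required.
void Handle(const std::shared_ptr<ps::core::MessageHandler> &message) { (void)message; }
}  // namespace server
}  // namespace fl
}  // namespace mindspore

int main() {
  mindspore::fl::server::Handle(std::make_shared<mindspore::ps::core::MessageHandler>());
  return 0;
}
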
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
-#define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
+#ifndef MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
+#define MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
 
 #include <memory>
 #include <string>
@@ -26,7 +26,7 @@
 #include "fl/server/kernel/round/round_kernel.h"
 
 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
 // Round helps server to handle network round messages and launch round kernels. One iteration in server consists of
 // multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting
@@ -37,7 +37,7 @@ class Round {
         bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
   ~Round() = default;
 
-  void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
+  void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
                   FinishIterCb finish_iteration_cb);
 
   // Reinitialize count service and round kernel of this round after scaling operations are done.
@@ -48,7 +48,7 @@ class Round {
 
   // This method is the callback which will be set to the communicator and called after the corresponding round message
   // is sent to the server.
-  void LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message);
+  void LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message);
 
   // Round needs to be reset after each iteration is finished or its timer expires.
   void Reset();
@@ -60,8 +60,8 @@ class Round {
 
  private:
   // The callbacks which will be set to DistributedCounterService.
-  void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
-  void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
+  void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
+  void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
 
   std::string name_;
 
@@ -83,7 +83,7 @@ class Round {
   // Whether this round uses the server number as its threshold count.
   bool server_num_as_threshold_;
 
-  std::shared_ptr<core::CommunicatorBase> communicator_;
+  std::shared_ptr<ps::core::CommunicatorBase> communicator_;
 
   // The round kernel for this Round.
   std::shared_ptr<kernel::RoundKernel> kernel_;
@@ -97,6 +97,6 @@ class Round {
   FinalizeCb finalize_cb_;
 };
 }  // namespace server
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
+#endif  // MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
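Reviewer note: read together, these declarations give the intended lifecycle: construct a Round with its counting policy, Initialize it against a communicator (which registers the message callback under the round's name), then Reset it at each iteration boundary. A usage sketch, under stated assumptions: `communicator` is an existing `std::shared_ptr<ps::core::CommunicatorBase>`, and both callback signatures follow the `(bool, const std::string &)` shape of the lambda in round.cc — neither is confirmed by this diff:

// Sketch only; not compilable outside the fl::server sources of this commit.
auto round = std::make_shared<mindspore::fl::server::Round>(
    "updateModel",  // round name, used when registering the message callback
    true,           // check_timeout: run this round against a timer
    3000,           // time_window for the timer (example value)
    true,           // check_count: complete when enough requests arrive
    8,              // threshold_count (the header's default)
    false);         // server_num_as_threshold
round->Initialize(communicator,
                  /*timeout_cb=*/[](bool, const std::string &) {},
                  /*finish_iteration_cb=*/[](bool, const std::string &) {});
// Once the iteration finishes or the timer expires, the round is reset:
round->Reset();
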
@@ -30,17 +30,8 @@
 #include "fl/server/kernel/round/round_kernel_factory.h"
 
 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
-static std::vector<std::shared_ptr<core::CommunicatorBase>> global_worker_server_comms = {};
-// This function is for the exit of server process when an interrupt signal is captured.
-void SignalHandler(int signal) {
-  MS_LOG(INFO) << "Interrupt signal captured: " << signal;
-  std::for_each(global_worker_server_comms.begin(), global_worker_server_comms.end(),
-                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
-  return;
-}
-
 void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
                         const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
   MS_EXCEPTION_IF_NULL(func_graph);
@@ -76,7 +67,6 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
 // Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
 // InitCipher---->InitExecutor
 void Server::Run() {
-  signal(SIGINT, SignalHandler);
   std::unique_lock<std::mutex> lock(scaling_mtx_);
   InitServerContext();
   InitCluster();
@@ -84,8 +74,8 @@ void Server::Run() {
   RegisterCommCallbacks();
   StartCommunicator();
   InitExecutor();
-  std::string encrypt_type = PSContext::instance()->encrypt_type();
-  if (encrypt_type != kNotEncryptType) {
+  std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
+  if (encrypt_type != ps::kNotEncryptType) {
     InitCipher();
     MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
   }
@@ -96,7 +86,7 @@ void Server::Run() {
 
   // Wait communicators to stop so the main thread is blocked.
   std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Join(); });
+                [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Join(); });
   communicator_with_server_->Join();
   MsException::Instance().CheckException();
   return;
@@ -115,18 +105,18 @@ void Server::CancelSafeMode() {
 bool Server::IsSafeMode() { return safemode_.load(); }
 
 void Server::InitServerContext() {
-  PSContext::instance()->GenerateResetterRound();
-  scheduler_ip_ = PSContext::instance()->scheduler_host();
-  scheduler_port_ = PSContext::instance()->scheduler_port();
-  worker_num_ = PSContext::instance()->initial_worker_num();
-  server_num_ = PSContext::instance()->initial_server_num();
+  ps::PSContext::instance()->GenerateResetterRound();
+  scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
+  scheduler_port_ = ps::PSContext::instance()->scheduler_port();
+  worker_num_ = ps::PSContext::instance()->initial_worker_num();
+  server_num_ = ps::PSContext::instance()->initial_server_num();
   return;
 }
 
 void Server::InitCluster() {
-  server_node_ = std::make_shared<core::ServerNode>();
+  server_node_ = std::make_shared<ps::core::ServerNode>();
   MS_EXCEPTION_IF_NULL(server_node_);
-  task_executor_ = std::make_shared<core::TaskExecutor>(32);
+  task_executor_ = std::make_shared<ps::core::TaskExecutor>(32);
   MS_EXCEPTION_IF_NULL(task_executor_);
   if (!InitCommunicatorWithServer()) {
     MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
@@ -136,7 +126,6 @@ void Server::InitCluster() {
     MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
     return;
   }
-  global_worker_server_comms = communicators_with_worker_;
   return;
 }
 
@@ -187,8 +176,8 @@ void Server::InitIteration() {
 }
 
 #ifdef ENABLE_ARMOUR
-  std::string encrypt_type = PSContext::instance()->encrypt_type();
-  if (encrypt_type == kPWEncryptType) {
+  std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
+  if (encrypt_type == ps::kPWEncryptType) {
     cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
     cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
     cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
@@ -245,10 +234,10 @@ void Server::InitCipher() {
   unsigned char cipher_p[SECRET_MAX_LEN] = {0};
   int cipher_g = 1;
   unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
-  float dp_eps = PSContext::instance()->dp_eps();
-  float dp_delta = PSContext::instance()->dp_delta();
-  float dp_norm_clip = PSContext::instance()->dp_norm_clip();
-  std::string encrypt_type = PSContext::instance()->encrypt_type();
+  float dp_eps = ps::PSContext::instance()->dp_eps();
+  float dp_delta = ps::PSContext::instance()->dp_delta();
+  float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
+  std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
 
   mpz_t prim;
   mpz_init(prim);
@@ -276,7 +265,7 @@ void Server::RegisterCommCallbacks() {
   // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
   // rounds' callbacks.
 
-  auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
+  auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
   MS_EXCEPTION_IF_NULL(tcp_comm);
 
   // Set message callbacks for server-to-server communication.
@@ -304,23 +293,23 @@ void Server::RegisterCommCallbacks() {
                                                     std::bind(&Server::ProcessAfterScalingIn, this));
 }
 
-void Server::RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
+void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
   MS_EXCEPTION_IF_NULL(communicator);
-  communicator->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
+  communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
     MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
     safemode_ = true;
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
+                  [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });
 
-  communicator->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [&]() {
+  communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
    MS_LOG(ERROR)
      << "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
         "network building phase.";
    safemode_ = true;
    std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
+                  [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
    communicator_with_server_->Stop();
  });
 }
@@ -377,7 +366,7 @@ void Server::StartCommunicator() {
 
   MS_LOG(INFO) << "Start communicator with worker.";
   std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Start(); });
+                [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Start(); });
 }
 
 void Server::ProcessBeforeScalingOut() {
@@ -398,23 +387,18 @@ void Server::ProcessAfterScalingOut() {
 
   if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
-    return;
   }
   if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
-    return;
   }
   if (!DistributedCountService::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
-    return;
   }
   if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
     MS_LOG(WARNING) << "Iteration reinitializing failed.";
-    return;
   }
   if (!Executor::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "Executor reinitializing failed.";
-    return;
   }
   std::this_thread::sleep_for(std::chrono::milliseconds(1000));
   safemode_ = false;
@@ -429,7 +413,7 @@ void Server::ProcessAfterScalingIn() {
   if (server_node_->rank_id() == UINT32_MAX) {
     MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
     std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
-                  [](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
+                  [](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
     communicator_with_server_->Stop();
     return;
   }
@@ -437,27 +421,22 @@ void Server::ProcessAfterScalingIn() {
   // If the server is not the one to be scaled in, reintialize modules and recover service.
   if (!DistributedMetadataStore::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
-    return;
   }
   if (!CollectiveOpsImpl::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "DistributedMetadataStore reinitializing failed.";
-    return;
   }
   if (!DistributedCountService::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "DistributedCountService reinitializing failed.";
-    return;
   }
   if (!iteration_->ReInitForScaling(IntToUint(server_node_->server_num()), server_node_->rank_id())) {
     MS_LOG(WARNING) << "Iteration reinitializing failed.";
-    return;
   }
   if (!Executor::GetInstance().ReInitForScaling()) {
     MS_LOG(WARNING) << "Executor reinitializing failed.";
-    return;
   }
   std::this_thread::sleep_for(std::chrono::milliseconds(1000));
   safemode_ = false;
 }
 }  // namespace server
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
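Reviewer note: the deleted block at the top of this file is a complete pattern on its own: a file-static communicator list plus a C signal handler that stops every communicator on SIGINT. A self-contained sketch of what was removed, with a toy `Communicator` type; note that real signal handlers may only do async-signal-safe work, which is one argument for deleting the original:

#include <csignal>
#include <cstdio>
#include <memory>
#include <vector>

struct Communicator {  // stand-in for ps::core::CommunicatorBase
  void Stop() { std::puts("communicator stopped"); }
};

static std::vector<std::shared_ptr<Communicator>> g_comms;

// Pattern from the removed code: stop every registered communicator on SIGINT.
// Calling into arbitrary code (even puts) from a handler is not strictly
// async-signal-safe; this sketch only demonstrates the control flow.
void SignalHandler(int) {
  for (const auto &comm : g_comms) {
    comm->Stop();
  }
}

int main() {
  g_comms.push_back(std::make_shared<Communicator>());
  std::signal(SIGINT, SignalHandler);
  std::raise(SIGINT);  // simulate Ctrl-C for the sketch
  return 0;
}
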
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
-#define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
+#ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
+#define MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
 
 #include <memory>
 #include <string>
@@ -31,7 +31,7 @@
 #endif
 
 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace server {
 // Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
 class Server {
@@ -90,7 +90,7 @@ class Server {
   void RegisterCommCallbacks();
 
   // Register cluster exception callbacks. This method is called in RegisterCommCallbacks.
-  void RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
+  void RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
 
   // Initialize executor according to the server mode.
   void InitExecutor();
@@ -113,11 +113,11 @@ class Server {
   void ProcessAfterScalingIn();
 
   // The server node is initialized in Server.
-  std::shared_ptr<core::ServerNode> server_node_;
+  std::shared_ptr<ps::core::ServerNode> server_node_;
 
   // The task executor of the communicators. This helps server to handle network message concurrently. The tasks
   // submitted to this task executor is asynchronous.
-  std::shared_ptr<core::TaskExecutor> task_executor_;
+  std::shared_ptr<ps::core::TaskExecutor> task_executor_;
 
   // Which protocol should communicators use.
   bool use_tcp_;
@@ -136,12 +136,12 @@ class Server {
 
   // Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective
   // operations, etc.
-  std::shared_ptr<core::CommunicatorBase> communicator_with_server_;
+  std::shared_ptr<ps::core::CommunicatorBase> communicator_with_server_;
 
   // The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP.
   // In some cases, both types should be supported in one distributed training job. So here we may have multiple
   // communicators.
-  std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;
+  std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
 
   // Mutex for scaling operations. We must wait server's initialization done before handle scaling events.
   std::mutex scaling_mtx_;
@@ -176,6 +176,6 @@ class Server {
   float percent_for_get_model_;
 };
 }  // namespace server
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
+#endif  // MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
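Reviewer note: the safe-mode machinery referenced throughout server.cc is a single atomic flag: `IsSafeMode()` is a plain `safemode_.load()`, while the SCHEDULER_TIMEOUT/NODE_TIMEOUT callbacks store `true` from communicator threads. A minimal self-contained sketch of that flag pattern (toy code, not the MindSpore types):

#include <atomic>
#include <iostream>
#include <thread>

std::atomic_bool safemode{false};  // mirrors Server's safemode_ member

int main() {
  // A "communicator" thread switches the server into safe mode, as the
  // cluster-event callbacks do in server.cc.
  std::thread watcher([] { safemode = true; });
  watcher.join();
  // Request paths consult the flag before doing real work.
  if (safemode.load()) {
    std::cout << "server is in safemode; rejecting requests\n";
  }
  return 0;
}
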
@@ -22,26 +22,27 @@
 #include "utils/ms_exception.h"
 
 namespace mindspore {
-namespace ps {
+namespace fl {
 namespace worker {
 void FLWorker::Run() {
-  worker_num_ = PSContext::instance()->worker_num();
-  server_num_ = PSContext::instance()->server_num();
-  scheduler_ip_ = PSContext::instance()->scheduler_ip();
-  scheduler_port_ = PSContext::instance()->scheduler_port();
-  worker_step_num_per_iteration_ = PSContext::instance()->worker_step_num_per_iteration();
-  PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_;
-  PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
-  PSContext::instance()->cluster_config().initial_worker_num = worker_num_;
-  PSContext::instance()->cluster_config().initial_server_num = server_num_;
+  worker_num_ = ps::PSContext::instance()->worker_num();
+  server_num_ = ps::PSContext::instance()->server_num();
+  scheduler_ip_ = ps::PSContext::instance()->scheduler_ip();
+  scheduler_port_ = ps::PSContext::instance()->scheduler_port();
+  worker_step_num_per_iteration_ = ps::PSContext::instance()->worker_step_num_per_iteration();
+  ps::PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_;
+  ps::PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
+  ps::PSContext::instance()->cluster_config().initial_worker_num = worker_num_;
+  ps::PSContext::instance()->cluster_config().initial_server_num = server_num_;
   MS_LOG(INFO) << "Initialize cluster config for worker. Worker number:" << worker_num_
                << ", Server number:" << server_num_ << ", Scheduler ip:" << scheduler_ip_
-               << ", Scheduler port:" << scheduler_port_;
+               << ", Scheduler port:" << scheduler_port_
+               << ", Worker training step per iteration:" << worker_step_num_per_iteration_;
 
-  worker_node_ = std::make_shared<core::WorkerNode>();
+  worker_node_ = std::make_shared<ps::core::WorkerNode>();
   MS_EXCEPTION_IF_NULL(worker_node_);
 
-  worker_node_->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
+  worker_node_->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
     Finalize();
     try {
       MS_LOG(EXCEPTION)
@@ -50,7 +51,7 @@ void FLWorker::Run() {
       MsException::Instance().SetException();
     }
   });
-  worker_node_->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() {
+  worker_node_->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
     Finalize();
     try {
       MS_LOG(EXCEPTION)
@@ -73,7 +74,7 @@ void FLWorker::Finalize() {
   worker_node_->Stop();
 }
 
-bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core::TcpUserCommand command,
+bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
                             std::shared_ptr<std::vector<unsigned char>> *output) {
   // If the worker is in safemode, do not communicate with server.
   while (safemode_.load()) {
@@ -96,7 +97,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core:
 
   if (output != nullptr) {
     while (true) {
-      if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output)) {
+      if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
+                              output)) {
         MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
         return false;
       }
@@ -105,7 +107,7 @@ bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core:
         return false;
       }
 
-      if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode) {
+      if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == ps::kClusterSafeMode) {
        MS_LOG(INFO) << "The server " << server_rank << " is in safemode.";
        std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode));
      } else {
@@ -113,7 +115,7 @@ bool FLWorker::SendToServer(uint32_t server_rank, void *data, size_t size, core:
       }
     }
   } else {
-    if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
+    if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
       MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
       return false;
     }
@@ -147,16 +149,16 @@ void FLWorker::InitializeFollowerScaler() {
   worker_node_->RegisterFollowerScalerBarrierBeforeScaleOut("WorkerPipeline",
                                                             std::bind(&FLWorker::ProcessBeforeScalingOut, this));
   worker_node_->RegisterFollowerScalerBarrierBeforeScaleIn("WorkerPipeline",
-                                                           std::bind(&FLWorker::ProcessBeforeScalingOut, this));
+                                                           std::bind(&FLWorker::ProcessBeforeScalingIn, this));
 
   // Set handlers after scheduler scaling operations are done.
   worker_node_->RegisterFollowerScalerHandlerAfterScaleOut("WorkerPipeline",
                                                            std::bind(&FLWorker::ProcessAfterScalingOut, this));
   worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
                                                           std::bind(&FLWorker::ProcessAfterScalingIn, this));
-  worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning),
+  worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
                                             std::bind(&FLWorker::HandleIterationRunningEvent, this));
-  worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted),
+  worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
                                             std::bind(&FLWorker::HandleIterationCompletedEvent, this));
 }
 
@@ -221,5 +223,5 @@ void FLWorker::ProcessAfterScalingIn() {
   safemode_ = false;
 }
 }  // namespace worker
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
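Reviewer note: for callers, the updated `SendToServer` signature means request buffers can now be passed as `const void *`, and a null `output` selects the fire-and-forget branch above; the kClusterSafeMode sleep-and-retry loop is handled inside the method. A call sketch, assuming `kPullWeight` is a valid `ps::core::TcpUserCommand` enumerator and using a placeholder payload (identifiers below are illustrative, not compilable outside this tree):

void PullWeightsFromServerZero() {
  std::shared_ptr<std::vector<unsigned char>> response = nullptr;
  const std::string request = "serialized request bytes";  // placeholder payload
  if (!mindspore::fl::worker::FLWorker::GetInstance().SendToServer(
          0, request.data(), request.size(),
          mindspore::ps::core::TcpUserCommand::kPullWeight, &response)) {
    // SendToServer already slept and retried while the server answered with
    // ps::kClusterSafeMode; reaching this branch means the send itself failed.
    return;
  }
  // response now holds the server's reply bytes.
}
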
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_
-#define MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_
+#ifndef MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
+#define MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
 
 #include <memory>
 #include <string>
@@ -28,14 +28,14 @@
 #include "ps/core/communicator/tcp_communicator.h"
 
 namespace mindspore {
-namespace ps {
+namespace fl {
 using FBBuilder = flatbuffers::FlatBufferBuilder;
 
 // The step number for worker to judge whether to communicate with server.
 constexpr uint32_t kTrainBeginStepNum = 1;
 constexpr uint32_t kTrainEndStepNum = 0;
 
-// The worker has to sleep for a while before the networking is completed.
+// The sleeping time of the worker thread before the networking is completed.
 constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;
 
 // The time duration between retrying when server is in safemode.
@@ -59,7 +59,7 @@ class FLWorker {
   }
   void Run();
   void Finalize();
-  bool SendToServer(uint32_t server_rank, void *data, size_t size, core::TcpUserCommand command,
+  bool SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
                     std::shared_ptr<std::vector<unsigned char>> *output = nullptr);
 
   uint32_t server_num() const;
@@ -104,10 +104,9 @@ class FLWorker {
   uint32_t worker_num_;
   std::string scheduler_ip_;
   uint16_t scheduler_port_;
-  std::shared_ptr<core::WorkerNode> worker_node_;
+  std::shared_ptr<ps::core::WorkerNode> worker_node_;
 
-  // The worker standalone training step number before communicating with server. This used in hybrid training mode for
-  // now.
+  // The worker standalone training step number before communicating with server. This used in hybrid training mode.
   uint64_t worker_step_num_per_iteration_;
 
   // The iteration state is either running or completed.
@@ -115,13 +114,13 @@ class FLWorker {
   // kIterationRunning/kIterationCompleted. triggered by server.
   std::atomic<IterationState> server_iteration_state_;
 
-  // The variable represents the worker iteration state and should be changed by worker training process.
+  // This variable represents the worker iteration state and should be changed by worker training process.
   std::atomic<IterationState> worker_iteration_state_;
 
   // The flag that represents whether worker is in safemode, which is decided by both worker and server iteration state.
   std::atomic_bool safemode_;
 };
 }  // namespace worker
-}  // namespace ps
+}  // namespace fl
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_
+#endif  // MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
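Reviewer note: the two step constants above encode when a worker talks to the server within an iteration: with N training steps per iteration, the step whose running counter is congruent to `kTrainBeginStepNum` (1) opens the iteration and the one congruent to `kTrainEndStepNum` (0) closes it; everything in between is standalone training. A self-contained illustration of that modulo convention (the 65 steps per iteration below is an arbitrary example value):

#include <cstdint>
#include <iostream>

constexpr uint32_t kTrainBeginStepNum = 1;
constexpr uint32_t kTrainEndStepNum = 0;

int main() {
  const uint64_t steps_per_iteration = 65;  // example value; configured per job
  for (uint64_t total_iteration = 1; total_iteration <= 130; ++total_iteration) {
    if (total_iteration % steps_per_iteration == kTrainBeginStepNum) {
      std::cout << "step " << total_iteration << ": first step, pull weights\n";
    } else if (total_iteration % steps_per_iteration == kTrainEndStepNum) {
      std::cout << "step " << total_iteration << ": last step, push weights\n";
    }
    // All other steps train locally with no server communication.
  }
  return 0;
}
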
@@ -639,7 +639,7 @@ bool StartPSWorkerAction(const ResourcePtr &res) {
   return true;
 }
 bool StartFLWorkerAction(const ResourcePtr &) {
-  ps::worker::FLWorker::GetInstance().Run();
+  fl::worker::FLWorker::GetInstance().Run();
   return true;
 }
 
@@ -665,7 +665,7 @@ bool StartServerAction(const ResourcePtr &res) {
   uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
   uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
 
-  std::vector<ps::server::RoundConfig> rounds_config = {
+  std::vector<fl::server::RoundConfig> rounds_config = {
     {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
     {"updateModel", true, update_model_time_window, true, update_model_threshold},
     {"getModel"},
@@ -676,22 +676,22 @@ bool StartServerAction(const ResourcePtr &res) {
   uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
   size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
 
-  ps::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
+  fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
 
   size_t executor_threshold = 0;
   if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
     executor_threshold = update_model_threshold;
-    ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
+    fl::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
                                                  executor_threshold);
   } else if (server_mode_ == ps::kServerModePS) {
     executor_threshold = worker_num;
-    ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
+    fl::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
                                                  executor_threshold);
   } else {
     MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
     return false;
   }
-  ps::server::Server::GetInstance().Run();
+  fl::server::Server::GetInstance().Run();
   return true;
 }
 
@@ -1293,7 +1293,7 @@ void ClearResAtexit() {
     MS_LOG(INFO) << "Start finalizing worker.";
     const std::string &server_mode = ps::PSContext::instance()->server_mode();
     if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid)) {
-      ps::worker::FLWorker::GetInstance().Finalize();
+      fl::worker::FLWorker::GetInstance().Finalize();
    } else {
      ps::Worker::GetInstance().Finalize();
    }
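Reviewer note: in `StartServerAction`, each `RoundConfig` is brace-initialized positionally; given the `Round` constructor in round.h, the fields line up as name, check_timeout, time_window, check_count, threshold_count, with `server_num_as_threshold` left at its default. A self-contained reading of those initializers, assuming `RoundConfig`'s members mirror the constructor parameters (the real struct definition is not shown in this diff, and the check_timeout/time_window defaults below are guesses; only the trailing three defaults come from round.h):

#include <cstddef>
#include <string>

// Assumed shape of RoundConfig; illustrative only.
struct RoundConfig {
  std::string name;
  bool check_timeout = false;      // assumed default
  size_t time_window = 3000;       // assumed default
  bool check_count = false;        // default per round.h
  size_t threshold_count = 8;      // default per round.h
  bool server_num_as_threshold = false;  // default per round.h
};

int main() {
  // The entries from StartServerAction, read positionally (example values):
  RoundConfig update_model{"updateModel", true, /*time_window=*/60000, true, /*threshold_count=*/32};
  RoundConfig get_model{"getModel"};  // everything except the name stays at its default
  return (update_model.check_count && !get_model.check_count) ? 0 : 1;
}
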
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 syntax = "proto3";
-package mindspore.ps;
+package mindspore.fl;
 
 message CollectiveData {
   bytes data = 1;
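Reviewer note: changing the proto `package` renames the generated C++ namespace as well: with standard protoc C++ codegen, `CollectiveData` is emitted under `mindspore::fl` instead of `mindspore::ps`, so every C++ user of the generated header must requalify. A sketch of the before/after (the header name below is hypothetical):

#include <cstddef>
#include "collective.pb.h"  // hypothetical generated header for this .proto

void FillCollectiveData(const void *chunk, size_t size) {
  // Before this commit: mindspore::ps::CollectiveData message;
  mindspore::fl::CollectiveData message;  // namespace follows the new package
  message.set_data(chunk, size);          // standard protobuf setter for a bytes field
}
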
@@ -286,8 +286,9 @@ void PSContext::GenerateResetterRound() {
     return;
   }
 
-  binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
-                          (is_mixed_training_mode << 2) | (secure_aggregation_ << 3);
+  binary_server_context = ((unsigned int)is_parameter_server_mode << 0) |
+                          ((unsigned int)is_federated_learning_mode << 1) |
+                          ((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)secure_aggregation_ << 3);
   if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
     resetter_round_ = ResetterRound::kNoNeedToReset;
   } else {
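Reviewer note: the rewritten expression widens each flag before shifting, which silences implicit-conversion warnings and makes the intent explicit: the four mode/security booleans pack into bits 0-3 of the key that `kServerContextToResetRoundMap` is indexed by. A self-contained worked example of the encoding:

#include <cstdio>

int main() {
  bool is_parameter_server_mode = false;
  bool is_federated_learning_mode = true;
  bool is_mixed_training_mode = false;
  bool secure_aggregation_ = true;

  // Same packing as PSContext::GenerateResetterRound: bit 0 = PS mode,
  // bit 1 = FL mode, bit 2 = mixed training, bit 3 = secure aggregation.
  unsigned int binary_server_context = ((unsigned int)is_parameter_server_mode << 0) |
                                       ((unsigned int)is_federated_learning_mode << 1) |
                                       ((unsigned int)is_mixed_training_mode << 2) |
                                       ((unsigned int)secure_aggregation_ << 3);

  // FL mode with secure aggregation: 0b1010 == 10, the map key for that context.
  std::printf("binary_server_context = %u\n", binary_server_context);
  return 0;
}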