forked from mindspore-Ecosystem/mindspore
!19689 Fix fl namespace issue.
Merge pull request !19689 from ZPaC/fix-namespace
This commit is contained in:
commit
d76bb99d8a
|
@ -47,13 +47,13 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ps::FBBuilder> fbb = std::make_shared<ps::FBBuilder>();
|
std::shared_ptr<fl::FBBuilder> fbb = std::make_shared<fl::FBBuilder>();
|
||||||
MS_EXCEPTION_IF_NULL(fbb);
|
MS_EXCEPTION_IF_NULL(fbb);
|
||||||
|
|
||||||
total_iteration_++;
|
total_iteration_++;
|
||||||
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
||||||
if (total_iteration_ % ps::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
|
if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
|
||||||
ps::kTrainBeginStepNum) {
|
fl::kTrainBeginStepNum) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,10 +72,10 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
|
const schema::ResponsePullWeight *pull_weight_rsp = nullptr;
|
||||||
int retcode = schema::ResponseCode_SucNotReady;
|
int retcode = schema::ResponseCode_SucNotReady;
|
||||||
while (retcode == schema::ResponseCode_SucNotReady) {
|
while (retcode == schema::ResponseCode_SucNotReady) {
|
||||||
if (!ps::worker::FLWorker::GetInstance().SendToServer(
|
if (!fl::worker::FLWorker::GetInstance().SendToServer(
|
||||||
0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
|
0, fbb->GetBufferPointer(), fbb->GetSize(), ps::core::TcpUserCommand::kPullWeight, &pull_weight_rsp_msg)) {
|
||||||
MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. This iteration is dropped.";
|
MS_LOG(WARNING) << "Sending request for FusedPullWeight to server 0 failed. This iteration is dropped.";
|
||||||
ps::worker::FLWorker::GetInstance().SetIterationRunning();
|
fl::worker::FLWorker::GetInstance().SetIterationRunning();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
|
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
|
||||||
|
@ -116,7 +116,7 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
|
MS_LOG(INFO) << "Pull weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
|
||||||
ps::worker::FLWorker::GetInstance().SetIterationRunning();
|
fl::worker::FLWorker::GetInstance().SetIterationRunning();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
void InitSizeLists() { return; }
|
void InitSizeLists() { return; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool BuildPullWeightReq(std::shared_ptr<ps::FBBuilder> fbb) {
|
bool BuildPullWeightReq(std::shared_ptr<fl::FBBuilder> fbb) {
|
||||||
MS_EXCEPTION_IF_NULL(fbb);
|
MS_EXCEPTION_IF_NULL(fbb);
|
||||||
std::vector<flatbuffers::Offset<flatbuffers::String>> fbs_weight_names;
|
std::vector<flatbuffers::Offset<flatbuffers::String>> fbs_weight_names;
|
||||||
for (const std::string &weight_name : weight_full_names_) {
|
for (const std::string &weight_name : weight_full_names_) {
|
||||||
|
|
|
@ -45,13 +45,13 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ps::FBBuilder> fbb = std::make_shared<ps::FBBuilder>();
|
std::shared_ptr<fl::FBBuilder> fbb = std::make_shared<fl::FBBuilder>();
|
||||||
MS_EXCEPTION_IF_NULL(fbb);
|
MS_EXCEPTION_IF_NULL(fbb);
|
||||||
|
|
||||||
total_iteration_++;
|
total_iteration_++;
|
||||||
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
||||||
if (total_iteration_ % ps::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
|
if (total_iteration_ % fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration() !=
|
||||||
ps::kTrainBeginStepNum) {
|
fl::kTrainBeginStepNum) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,17 +67,17 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
// The server number may change after scaling in/out.
|
// The server number may change after scaling in/out.
|
||||||
for (uint32_t i = 0; i < ps::worker::FLWorker::GetInstance().server_num(); i++) {
|
for (uint32_t i = 0; i < fl::worker::FLWorker::GetInstance().server_num(); i++) {
|
||||||
std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr;
|
std::shared_ptr<std::vector<unsigned char>> push_weight_rsp_msg = nullptr;
|
||||||
const schema::ResponsePushWeight *push_weight_rsp = nullptr;
|
const schema::ResponsePushWeight *push_weight_rsp = nullptr;
|
||||||
int retcode = schema::ResponseCode_SucNotReady;
|
int retcode = schema::ResponseCode_SucNotReady;
|
||||||
while (retcode == schema::ResponseCode_SucNotReady) {
|
while (retcode == schema::ResponseCode_SucNotReady) {
|
||||||
if (!ps::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
|
if (!fl::worker::FLWorker::GetInstance().SendToServer(i, fbb->GetBufferPointer(), fbb->GetSize(),
|
||||||
ps::core::TcpUserCommand::kPushWeight,
|
ps::core::TcpUserCommand::kPushWeight,
|
||||||
&push_weight_rsp_msg)) {
|
&push_weight_rsp_msg)) {
|
||||||
MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i
|
MS_LOG(WARNING) << "Sending request for FusedPushWeight to server " << i
|
||||||
<< " failed. This iteration is dropped.";
|
<< " failed. This iteration is dropped.";
|
||||||
ps::worker::FLWorker::GetInstance().SetIterationCompleted();
|
fl::worker::FLWorker::GetInstance().SetIterationCompleted();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
|
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
|
||||||
|
@ -105,7 +105,7 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
|
MS_LOG(INFO) << "Push weights for " << weight_full_names_ << " succeed. Iteration: " << fl_iteration_;
|
||||||
ps::worker::FLWorker::GetInstance().SetIterationCompleted();
|
fl::worker::FLWorker::GetInstance().SetIterationCompleted();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
void InitSizeLists() { return; }
|
void InitSizeLists() { return; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool BuildPushWeightReq(std::shared_ptr<ps::FBBuilder> fbb, const std::vector<AddressPtr> &weights) {
|
bool BuildPushWeightReq(std::shared_ptr<fl::FBBuilder> fbb, const std::vector<AddressPtr> &weights) {
|
||||||
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
|
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
|
||||||
for (size_t i = 0; i < weight_full_names_.size(); i++) {
|
for (size_t i = 0; i < weight_full_names_.size(); i++) {
|
||||||
const std::string &weight_name = weight_full_names_[i];
|
const std::string &weight_name = weight_full_names_[i];
|
||||||
|
|
|
@ -31,8 +31,8 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
||||||
int return_num = 0;
|
int return_num = 0;
|
||||||
cipher_meta_storage_.RegisterClass();
|
cipher_meta_storage_.RegisterClass();
|
||||||
const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN);
|
const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN);
|
||||||
cipher_meta_storage_.RegisterPrime(ps::server::kCtxCipherPrimer, new_prime);
|
cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime);
|
||||||
if (!cipher_meta_storage_.GetPrimeFromServer(ps::server::kCtxCipherPrimer, publicparam_.prime)) {
|
if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) {
|
||||||
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
|
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -45,7 +45,7 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
||||||
publicparam_.t = param.t;
|
publicparam_.t = param.t;
|
||||||
secrets_minnums_ = param.t;
|
secrets_minnums_ = param.t;
|
||||||
client_num_need_ = cipher_initial_client_cnt;
|
client_num_need_ = cipher_initial_client_cnt;
|
||||||
featuremap_ = ps::server::ModelStore::GetInstance().model_size() / sizeof(float);
|
featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float);
|
||||||
share_clients_num_need_ = cipher_share_secrets_cnt;
|
share_clients_num_need_ = cipher_share_secrets_cnt;
|
||||||
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
|
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
|
||||||
get_model_num_need_ = cipher_get_clientlist_cnt;
|
get_model_num_need_ = cipher_get_clientlist_cnt;
|
||||||
|
|
|
@ -21,7 +21,7 @@ namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
|
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::GetExchangeKeys *get_exchange_keys_req,
|
const schema::GetExchangeKeys *get_exchange_keys_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> get_exchange_keys_resp_builder) {
|
std::shared_ptr<fl::server::FBBuilder> get_exchange_keys_resp_builder) {
|
||||||
MS_LOG(INFO) << "CipherMgr::GetKeys START";
|
MS_LOG(INFO) << "CipherMgr::GetKeys START";
|
||||||
if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) {
|
if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) {
|
||||||
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
|
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
|
||||||
|
@ -32,7 +32,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
|
||||||
// get clientlist from memory server.
|
// get clientlist from memory server.
|
||||||
std::vector<std::string> clients;
|
std::vector<std::string> clients;
|
||||||
|
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList, &clients);
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &clients);
|
||||||
|
|
||||||
size_t cur_clients_num = clients.size();
|
size_t cur_clients_num = clients.size();
|
||||||
std::string fl_id = get_exchange_keys_req->fl_id()->str();
|
std::string fl_id = get_exchange_keys_req->fl_id()->str();
|
||||||
|
@ -61,7 +61,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
|
||||||
|
|
||||||
bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::RequestExchangeKeys *exchange_keys_req,
|
const schema::RequestExchangeKeys *exchange_keys_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder) {
|
std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder) {
|
||||||
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
|
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
|
||||||
// step 0: judge if the input param is legal.
|
// step 0: judge if the input param is legal.
|
||||||
if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) {
|
if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) {
|
||||||
|
@ -75,8 +75,8 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
|
||||||
// step 1: get clientlist and client keys from memory server.
|
// step 1: get clientlist and client keys from memory server.
|
||||||
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
||||||
std::vector<std::string> client_list;
|
std::vector<std::string> client_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList, &client_list);
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
|
||||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys);
|
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
||||||
|
|
||||||
// step2: process new item data. and update new item data to memory server.
|
// step2: process new item data. and update new item data to memory server.
|
||||||
size_t cur_clients_num = client_list.size();
|
size_t cur_clients_num = client_list.size();
|
||||||
|
@ -131,9 +131,9 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
|
||||||
cur_public_key.push_back(spk);
|
cur_public_key.push_back(spk);
|
||||||
|
|
||||||
bool retcode_key =
|
bool retcode_key =
|
||||||
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(ps::server::kCtxClientsKeys, fl_id, cur_public_key);
|
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, fl_id, cur_public_key);
|
||||||
bool retcode_client =
|
bool retcode_client =
|
||||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxExChangeKeysClientList, fl_id);
|
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
|
||||||
if (retcode_key && retcode_client) {
|
if (retcode_key && retcode_client) {
|
||||||
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
|
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
|
||||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
|
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
|
||||||
|
@ -147,7 +147,7 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder,
|
void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder,
|
||||||
const schema::ResponseCode retcode, const std::string &reason,
|
const schema::ResponseCode retcode, const std::string &reason,
|
||||||
const std::string &next_req_time, const int iteration) {
|
const std::string &next_req_time, const int iteration) {
|
||||||
auto rsp_reason = exchange_keys_resp_builder->CreateString(reason);
|
auto rsp_reason = exchange_keys_resp_builder->CreateString(reason);
|
||||||
|
@ -162,7 +162,7 @@ void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exc
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
bool CipherKeys::BuildGetKeys(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||||
const int iteration, const std::string &next_req_time, bool is_good) {
|
const int iteration, const std::string &next_req_time, bool is_good) {
|
||||||
bool flag = true;
|
bool flag = true;
|
||||||
if (is_good) {
|
if (is_good) {
|
||||||
|
@ -170,7 +170,7 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
|
||||||
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
|
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
|
||||||
MS_LOG(INFO) << "Get Keys: ";
|
MS_LOG(INFO) << "Get Keys: ";
|
||||||
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys);
|
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
||||||
if (record_public_keys.size() < cipher_init_->client_num_need_) {
|
if (record_public_keys.size() < cipher_init_->client_num_need_) {
|
||||||
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
|
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
|
||||||
<< "clients num: " << cipher_init_->client_num_need_;
|
<< "clients num: " << cipher_init_->client_num_need_;
|
||||||
|
@ -221,8 +221,8 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherKeys::ClearKeys() {
|
void CipherKeys::ClearKeys() {
|
||||||
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxExChangeKeysClientList);
|
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxExChangeKeysClientList);
|
||||||
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsKeys);
|
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsKeys);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace armour
|
} // namespace armour
|
||||||
|
|
|
@ -45,18 +45,18 @@ class CipherKeys {
|
||||||
// handle the client's request of get keys.
|
// handle the client's request of get keys.
|
||||||
bool GetKeys(const int cur_iterator, const std::string &next_req_time,
|
bool GetKeys(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::GetExchangeKeys *get_exchange_keys_req,
|
const schema::GetExchangeKeys *get_exchange_keys_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> get_exchange_keys_resp_builder);
|
std::shared_ptr<fl::server::FBBuilder> get_exchange_keys_resp_builder);
|
||||||
|
|
||||||
// handle the client's request of exchange keys.
|
// handle the client's request of exchange keys.
|
||||||
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::RequestExchangeKeys *exchange_keys_req,
|
const schema::RequestExchangeKeys *exchange_keys_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder);
|
std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder);
|
||||||
|
|
||||||
// build response code of get keys.
|
// build response code of get keys.
|
||||||
bool BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode, const int iteration,
|
bool BuildGetKeys(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode, const int iteration,
|
||||||
const std::string &next_req_time, bool is_good);
|
const std::string &next_req_time, bool is_good);
|
||||||
// build response code of exchange keys.
|
// build response code of exchange keys.
|
||||||
void BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exchange_keys_resp_builder,
|
void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> exchange_keys_resp_builder,
|
||||||
const schema::ResponseCode retcode, const std::string &reason,
|
const schema::ResponseCode retcode, const std::string &reason,
|
||||||
const std::string &next_req_time, const int iteration);
|
const std::string &next_req_time, const int iteration);
|
||||||
// clear the shared memory.
|
// clear the shared memory.
|
||||||
|
|
|
@ -21,16 +21,16 @@ namespace armour {
|
||||||
|
|
||||||
void CipherMetaStorage::GetClientSharesFromServer(
|
void CipherMetaStorage::GetClientSharesFromServer(
|
||||||
const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) {
|
const char *list_name, std::map<std::string, std::vector<clientshare_str>> *clients_shares_list) {
|
||||||
const ps::PBMetadata &clients_shares_pb_out =
|
const fl::PBMetadata &clients_shares_pb_out =
|
||||||
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||||
const ps::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
|
const fl::ClientShares &clients_shares_pb = clients_shares_pb_out.client_shares();
|
||||||
auto iter = clients_shares_pb.client_secret_shares().begin();
|
auto iter = clients_shares_pb.client_secret_shares().begin();
|
||||||
for (; iter != clients_shares_pb.client_secret_shares().end(); ++iter) {
|
for (; iter != clients_shares_pb.client_secret_shares().end(); ++iter) {
|
||||||
std::string fl_id = iter->first;
|
std::string fl_id = iter->first;
|
||||||
const ps::SharesPb &shares_pb = iter->second;
|
const fl::SharesPb &shares_pb = iter->second;
|
||||||
std::vector<clientshare_str> encrpted_shares_new;
|
std::vector<clientshare_str> encrpted_shares_new;
|
||||||
for (int index_shares = 0; index_shares < shares_pb.clientsharestrs_size(); ++index_shares) {
|
for (int index_shares = 0; index_shares < shares_pb.clientsharestrs_size(); ++index_shares) {
|
||||||
const ps::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
|
const fl::ClientShareStr &client_share_str_pb = shares_pb.clientsharestrs(index_shares);
|
||||||
clientshare_str new_clientshare;
|
clientshare_str new_clientshare;
|
||||||
new_clientshare.fl_id = client_share_str_pb.fl_id();
|
new_clientshare.fl_id = client_share_str_pb.fl_id();
|
||||||
new_clientshare.index = client_share_str_pb.index();
|
new_clientshare.index = client_share_str_pb.index();
|
||||||
|
@ -42,8 +42,8 @@ void CipherMetaStorage::GetClientSharesFromServer(
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) {
|
void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list) {
|
||||||
const ps::PBMetadata &client_list_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
const fl::PBMetadata &client_list_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||||
const ps::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
|
const fl::UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
|
||||||
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
||||||
std::string fl_id = client_list_pb.fl_id(i);
|
std::string fl_id = client_list_pb.fl_id(i);
|
||||||
clients_list->push_back(fl_id);
|
clients_list->push_back(fl_id);
|
||||||
|
@ -52,14 +52,14 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect
|
||||||
|
|
||||||
void CipherMetaStorage::GetClientKeysFromServer(
|
void CipherMetaStorage::GetClientKeysFromServer(
|
||||||
const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) {
|
const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) {
|
||||||
const ps::PBMetadata &clients_keys_pb_out =
|
const fl::PBMetadata &clients_keys_pb_out =
|
||||||
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||||
const ps::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
||||||
|
|
||||||
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
|
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
|
||||||
// const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
|
// const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
|
||||||
std::string fl_id = iter->first;
|
std::string fl_id = iter->first;
|
||||||
ps::KeysPb keys_pb = iter->second;
|
fl::KeysPb keys_pb = iter->second;
|
||||||
std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||||
std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
||||||
std::vector<std::vector<unsigned char>> cur_keys;
|
std::vector<std::vector<unsigned char>> cur_keys;
|
||||||
|
@ -70,9 +70,9 @@ void CipherMetaStorage::GetClientKeysFromServer(
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) {
|
bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise) {
|
||||||
const ps::PBMetadata &clients_noises_pb_out =
|
const fl::PBMetadata &clients_noises_pb_out =
|
||||||
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||||
const ps::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
||||||
while (clients_noises_pb.has_one_client_noises() == false) {
|
while (clients_noises_pb.has_one_client_noises() == false) {
|
||||||
MS_LOG(INFO) << "GetClientNoisesFromServer NULL.";
|
MS_LOG(INFO) << "GetClientNoisesFromServer NULL.";
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
||||||
|
@ -83,8 +83,8 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) {
|
bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) {
|
||||||
const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
|
const fl::PBMetadata &prime_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
|
||||||
ps::Prime prime_pb(prime_pb_out.prime());
|
fl::Prime prime_pb(prime_pb_out.prime());
|
||||||
std::string str = *(prime_pb.mutable_prime());
|
std::string str = *(prime_pb.mutable_prime());
|
||||||
MS_LOG(INFO) << "get prime from metastorage :" << str;
|
MS_LOG(INFO) << "get prime from metastorage :" << str;
|
||||||
|
|
||||||
|
@ -99,20 +99,20 @@ bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char
|
||||||
|
|
||||||
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
|
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
|
||||||
bool retcode = true;
|
bool retcode = true;
|
||||||
ps::FLId fl_id_pb;
|
fl::FLId fl_id_pb;
|
||||||
fl_id_pb.set_fl_id(fl_id);
|
fl_id_pb.set_fl_id(fl_id);
|
||||||
ps::PBMetadata client_pb;
|
fl::PBMetadata client_pb;
|
||||||
client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
|
client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
|
||||||
retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
|
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
|
||||||
return retcode;
|
return retcode;
|
||||||
}
|
}
|
||||||
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
|
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
|
||||||
MS_LOG(INFO) << "register prime: " << prime;
|
MS_LOG(INFO) << "register prime: " << prime;
|
||||||
ps::Prime prime_id_pb;
|
fl::Prime prime_id_pb;
|
||||||
prime_id_pb.set_prime(prime);
|
prime_id_pb.set_prime(prime);
|
||||||
ps::PBMetadata prime_pb;
|
fl::PBMetadata prime_pb;
|
||||||
prime_pb.mutable_prime()->MergeFrom(prime_id_pb);
|
prime_pb.mutable_prime()->MergeFrom(prime_id_pb);
|
||||||
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(list_name, prime_pb);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
||||||
|
@ -123,25 +123,25 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// update new item to memory server.
|
// update new item to memory server.
|
||||||
ps::KeysPb keys;
|
fl::KeysPb keys;
|
||||||
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
|
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
|
||||||
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
|
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
|
||||||
ps::PairClientKeys pair_client_keys_pb;
|
fl::PairClientKeys pair_client_keys_pb;
|
||||||
pair_client_keys_pb.set_fl_id(fl_id);
|
pair_client_keys_pb.set_fl_id(fl_id);
|
||||||
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
|
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
|
||||||
ps::PBMetadata client_and_keys_pb;
|
fl::PBMetadata client_and_keys_pb;
|
||||||
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
|
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
|
||||||
retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
||||||
return retcode;
|
return retcode;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) {
|
bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) {
|
||||||
// update new item to memory server.
|
// update new item to memory server.
|
||||||
ps::OneClientNoises noises_pb;
|
fl::OneClientNoises noises_pb;
|
||||||
*noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
|
*noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
|
||||||
ps::PBMetadata client_noises_pb;
|
fl::PBMetadata client_noises_pb;
|
||||||
client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
|
client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
|
||||||
return ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
|
return fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherMetaStorage::UpdateClientShareToServer(
|
bool CipherMetaStorage::UpdateClientShareToServer(
|
||||||
|
@ -149,10 +149,10 @@ bool CipherMetaStorage::UpdateClientShareToServer(
|
||||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
|
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
|
||||||
bool retcode = true;
|
bool retcode = true;
|
||||||
int size_shares = shares->size();
|
int size_shares = shares->size();
|
||||||
ps::SharesPb shares_pb;
|
fl::SharesPb shares_pb;
|
||||||
for (int index = 0; index < size_shares; ++index) {
|
for (int index = 0; index < size_shares; ++index) {
|
||||||
// new item
|
// new item
|
||||||
ps::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
|
fl::ClientShareStr *client_share_str_new_p = shares_pb.add_clientsharestrs();
|
||||||
std::string fl_id_new = (*shares)[index]->fl_id()->str();
|
std::string fl_id_new = (*shares)[index]->fl_id()->str();
|
||||||
int index_new = (*shares)[index]->index();
|
int index_new = (*shares)[index]->index();
|
||||||
auto share = (*shares)[index]->share();
|
auto share = (*shares)[index]->share();
|
||||||
|
@ -160,32 +160,32 @@ bool CipherMetaStorage::UpdateClientShareToServer(
|
||||||
client_share_str_new_p->set_fl_id(fl_id_new);
|
client_share_str_new_p->set_fl_id(fl_id_new);
|
||||||
client_share_str_new_p->set_index(index_new);
|
client_share_str_new_p->set_index(index_new);
|
||||||
}
|
}
|
||||||
ps::PairClientShares pair_client_shares_pb;
|
fl::PairClientShares pair_client_shares_pb;
|
||||||
pair_client_shares_pb.set_fl_id(fl_id);
|
pair_client_shares_pb.set_fl_id(fl_id);
|
||||||
pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
|
pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
|
||||||
ps::PBMetadata client_and_shares_pb;
|
fl::PBMetadata client_and_shares_pb;
|
||||||
client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
|
client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
|
||||||
retcode = ps::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
|
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
|
||||||
return retcode;
|
return retcode;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherMetaStorage::RegisterClass() {
|
void CipherMetaStorage::RegisterClass() {
|
||||||
ps::PBMetadata exchange_kyes_client_list;
|
fl::PBMetadata exchange_kyes_client_list;
|
||||||
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxExChangeKeysClientList,
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
|
||||||
exchange_kyes_client_list);
|
exchange_kyes_client_list);
|
||||||
ps::PBMetadata clients_keys;
|
fl::PBMetadata clients_keys;
|
||||||
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsKeys, clients_keys);
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
|
||||||
ps::PBMetadata reconstruct_client_list;
|
fl::PBMetadata reconstruct_client_list;
|
||||||
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxReconstructClientList,
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxReconstructClientList,
|
||||||
reconstruct_client_list);
|
reconstruct_client_list);
|
||||||
ps::PBMetadata clients_reconstruct_shares;
|
fl::PBMetadata clients_reconstruct_shares;
|
||||||
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsReconstructShares,
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsReconstructShares,
|
||||||
clients_reconstruct_shares);
|
clients_reconstruct_shares);
|
||||||
ps::PBMetadata share_secretes_client_list;
|
fl::PBMetadata share_secretes_client_list;
|
||||||
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxShareSecretsClientList,
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
|
||||||
share_secretes_client_list);
|
share_secretes_client_list);
|
||||||
ps::PBMetadata clients_encrypt_shares;
|
fl::PBMetadata clients_encrypt_shares;
|
||||||
ps::server::DistributedMetadataStore::GetInstance().RegisterMetadata(ps::server::kCtxClientsEncryptedShares,
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
|
||||||
clients_encrypt_shares);
|
clients_encrypt_shares);
|
||||||
}
|
}
|
||||||
} // namespace armour
|
} // namespace armour
|
||||||
|
|
|
@ -101,15 +101,15 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
|
||||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
|
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
|
||||||
bool retcode = true;
|
bool retcode = true;
|
||||||
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori;
|
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsReconstructShares,
|
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
|
||||||
&reconstruct_secret_list_ori);
|
&reconstruct_secret_list_ori);
|
||||||
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(ps::server::kCtxClientsKeys, &record_public_keys);
|
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
||||||
std::vector<std::string> clients_reconstruct_list;
|
std::vector<std::string> clients_reconstruct_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxReconstructClientList,
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
|
||||||
&clients_reconstruct_list);
|
&clients_reconstruct_list);
|
||||||
std::vector<std::string> clients_share_list;
|
std::vector<std::string> clients_share_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList,
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
||||||
&clients_share_list);
|
&clients_share_list);
|
||||||
if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() ||
|
if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() ||
|
||||||
record_public_keys.size() < cipher_init_->client_num_need_ ||
|
record_public_keys.size() < cipher_init_->client_num_need_ ||
|
||||||
|
@ -146,7 +146,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
|
||||||
client_keys.clear();
|
client_keys.clear();
|
||||||
MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";
|
MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";
|
||||||
|
|
||||||
if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(ps::server::kCtxClientNoises, noise) == false)
|
if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise) == false)
|
||||||
return false;
|
return false;
|
||||||
MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
|
MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
|
||||||
} else {
|
} else {
|
||||||
|
@ -159,7 +159,7 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
|
||||||
// reconstruct secrets
|
// reconstruct secrets
|
||||||
bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
|
bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::SendReconstructSecret *reconstruct_secret_req,
|
const schema::SendReconstructSecret *reconstruct_secret_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> reconstruct_secret_resp_builder,
|
std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
|
||||||
const std::vector<std::string> &client_list) {
|
const std::vector<std::string> &client_list) {
|
||||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
|
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
|
@ -178,10 +178,10 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::vector<std::string> clients_reconstruct_list;
|
std::vector<std::string> clients_reconstruct_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxReconstructClientList,
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
|
||||||
&clients_reconstruct_list);
|
&clients_reconstruct_list);
|
||||||
std::map<std::string, std::vector<clientshare_str>> clients_shares_all;
|
std::map<std::string, std::vector<clientshare_str>> clients_shares_all;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsReconstructShares,
|
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
|
||||||
&clients_shares_all);
|
&clients_shares_all);
|
||||||
|
|
||||||
size_t count_client_num = clients_shares_all.size();
|
size_t count_client_num = clients_shares_all.size();
|
||||||
|
@ -215,9 +215,9 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
|
||||||
}
|
}
|
||||||
auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
|
auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
|
||||||
bool retcode_client =
|
bool retcode_client =
|
||||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxReconstructClientList, fl_id);
|
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id);
|
||||||
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
||||||
ps::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
|
fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
|
||||||
if (!(retcode_client && retcode_share)) {
|
if (!(retcode_client && retcode_share)) {
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
||||||
"reconstruct update shares or client failed.", cur_iterator, next_req_time);
|
"reconstruct update shares or client failed.", cur_iterator, next_req_time);
|
||||||
|
@ -233,9 +233,9 @@ bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::st
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
bool retcode_result = true;
|
bool retcode_result = true;
|
||||||
const ps::PBMetadata &clients_noises_pb_out =
|
const fl::PBMetadata &clients_noises_pb_out =
|
||||||
ps::server::DistributedMetadataStore::GetInstance().GetMetadata(ps::server::kCtxClientNoises);
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
|
||||||
const ps::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
||||||
if (clients_noises_pb.has_one_client_noises() == false) {
|
if (clients_noises_pb.has_one_client_noises() == false) {
|
||||||
MS_LOG(INFO) << "Success,the secret will be reconstructed.";
|
MS_LOG(INFO) << "Success,the secret will be reconstructed.";
|
||||||
retcode_result = ReconstructSecretsGenNoise(client_list);
|
retcode_result = ReconstructSecretsGenNoise(client_list);
|
||||||
|
@ -279,13 +279,13 @@ bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,
|
||||||
|
|
||||||
void CipherReconStruct::ClearReconstructSecrets() {
|
void CipherReconStruct::ClearReconstructSecrets() {
|
||||||
MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets START";
|
MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets START";
|
||||||
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxReconstructClientList);
|
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxReconstructClientList);
|
||||||
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsReconstructShares);
|
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsReconstructShares);
|
||||||
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientNoises);
|
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientNoises);
|
||||||
MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success";
|
MS_LOG(INFO) << "CipherReconStruct::ClearReconstructSecrets Success";
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<ps::server::FBBuilder> fbb,
|
void CipherReconStruct::BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb,
|
||||||
const schema::ResponseCode retcode, const std::string &reason,
|
const schema::ResponseCode retcode, const std::string &reason,
|
||||||
const int iteration, const std::string &next_req_time) {
|
const int iteration, const std::string &next_req_time) {
|
||||||
auto fbs_reason = fbb->CreateString(reason);
|
auto fbs_reason = fbb->CreateString(reason);
|
||||||
|
|
|
@ -44,11 +44,11 @@ class CipherReconStruct {
|
||||||
// reconstruct secret mask
|
// reconstruct secret mask
|
||||||
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
|
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::SendReconstructSecret *reconstruct_secret_req,
|
const schema::SendReconstructSecret *reconstruct_secret_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> reconstruct_secret_resp_builder,
|
std::shared_ptr<fl::server::FBBuilder> reconstruct_secret_resp_builder,
|
||||||
const std::vector<std::string> &client_list);
|
const std::vector<std::string> &client_list);
|
||||||
|
|
||||||
// build response code of reconstruct secret.
|
// build response code of reconstruct secret.
|
||||||
void BuildReconstructSecretsRsp(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
void BuildReconstructSecretsRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||||
const std::string &reason, const int iteration, const std::string &next_req_time);
|
const std::string &reason, const int iteration, const std::string &next_req_time);
|
||||||
|
|
||||||
// clear the shared memory.
|
// clear the shared memory.
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
|
bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder,
|
std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
|
||||||
const string next_req_time) {
|
const string next_req_time) {
|
||||||
MS_LOG(INFO) << "CipherShares::ShareSecrets START";
|
MS_LOG(INFO) << "CipherShares::ShareSecrets START";
|
||||||
if (share_secrets_req == nullptr) {
|
if (share_secrets_req == nullptr) {
|
||||||
|
@ -35,13 +35,13 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
|
||||||
// step 1: get client list and share secrets from memory server.
|
// step 1: get client list and share secrets from memory server.
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
std::vector<std::string> clients_share_list;
|
std::vector<std::string> clients_share_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList,
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
||||||
&clients_share_list);
|
&clients_share_list);
|
||||||
std::vector<std::string> clients_exchange_list;
|
std::vector<std::string> clients_exchange_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxExChangeKeysClientList,
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList,
|
||||||
&clients_exchange_list);
|
&clients_exchange_list);
|
||||||
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsEncryptedShares,
|
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
|
||||||
&encrypted_shares_all);
|
&encrypted_shares_all);
|
||||||
|
|
||||||
MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size()
|
MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size()
|
||||||
|
@ -75,9 +75,9 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
|
||||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
|
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
|
||||||
(share_secrets_req->encrypted_shares());
|
(share_secrets_req->encrypted_shares());
|
||||||
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
||||||
ps::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
|
fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
|
||||||
bool retcode_client =
|
bool retcode_client =
|
||||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxShareSecretsClientList, fl_id_src);
|
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src);
|
||||||
bool retcode = retcode_share && retcode_client;
|
bool retcode = retcode_share && retcode_client;
|
||||||
if (retcode) {
|
if (retcode) {
|
||||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
|
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
|
||||||
|
@ -95,7 +95,7 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder,
|
std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
|
||||||
const std::string &next_req_time) {
|
const std::string &next_req_time) {
|
||||||
MS_LOG(INFO) << "CipherShares::GetSecrets START";
|
MS_LOG(INFO) << "CipherShares::GetSecrets START";
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
|
@ -108,10 +108,10 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
||||||
|
|
||||||
// step 1: get client list and client shares list from memory server.
|
// step 1: get client list and client shares list from memory server.
|
||||||
std::vector<std::string> clients_share_list;
|
std::vector<std::string> clients_share_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(ps::server::kCtxShareSecretsClientList,
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
||||||
&clients_share_list);
|
&clients_share_list);
|
||||||
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(ps::server::kCtxClientsEncryptedShares,
|
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
|
||||||
&encrypted_shares_all);
|
&encrypted_shares_all);
|
||||||
int iteration = get_secrets_req->iteration();
|
int iteration = get_secrets_req->iteration();
|
||||||
size_t share_clients_num = clients_share_list.size();
|
size_t share_clients_num = clients_share_list.size();
|
||||||
|
@ -180,7 +180,7 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherShares::BuildGetSecretsRsp(
|
void CipherShares::BuildGetSecretsRsp(
|
||||||
std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
|
std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, schema::ResponseCode retcode, int iteration,
|
||||||
std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) {
|
std::string next_req_time, std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares) {
|
||||||
int rsp_retcode = retcode;
|
int rsp_retcode = retcode;
|
||||||
int rsp_iteration = iteration;
|
int rsp_iteration = iteration;
|
||||||
|
@ -199,7 +199,7 @@ void CipherShares::BuildGetSecretsRsp(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherShares::BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder,
|
void CipherShares::BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
|
||||||
const schema::ResponseCode retcode, const string &reason,
|
const schema::ResponseCode retcode, const string &reason,
|
||||||
const string &next_req_time, const int iteration) {
|
const string &next_req_time, const int iteration) {
|
||||||
auto rsp_reason = share_secrets_resp_builder->CreateString(reason);
|
auto rsp_reason = share_secrets_resp_builder->CreateString(reason);
|
||||||
|
@ -211,8 +211,8 @@ void CipherShares::BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> s
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherShares::ClearShareSecrets() {
|
void CipherShares::ClearShareSecrets() {
|
||||||
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxShareSecretsClientList);
|
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxShareSecretsClientList);
|
||||||
ps::server::DistributedMetadataStore::GetInstance().ResetMetadata(ps::server::kCtxClientsEncryptedShares);
|
fl::server::DistributedMetadataStore::GetInstance().ResetMetadata(fl::server::kCtxClientsEncryptedShares);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace armour
|
} // namespace armour
|
||||||
|
|
|
@ -43,17 +43,17 @@ class CipherShares {
|
||||||
|
|
||||||
// handle the client's request of share secrets.
|
// handle the client's request of share secrets.
|
||||||
bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
|
bool ShareSecrets(const int cur_iterator, const schema::RequestShareSecrets *share_secrets_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder, const string next_req_time);
|
std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder, const string next_req_time);
|
||||||
// handle the client's request of get secrets.
|
// handle the client's request of get secrets.
|
||||||
bool GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
bool GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
||||||
std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time);
|
std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder, const std::string &next_req_time);
|
||||||
|
|
||||||
// build response code of share secrets.
|
// build response code of share secrets.
|
||||||
void BuildShareSecretsRsp(std::shared_ptr<ps::server::FBBuilder> share_secrets_resp_builder,
|
void BuildShareSecretsRsp(std::shared_ptr<fl::server::FBBuilder> share_secrets_resp_builder,
|
||||||
const schema::ResponseCode retcode, const string &reason, const string &next_req_time,
|
const schema::ResponseCode retcode, const string &reason, const string &next_req_time,
|
||||||
const int iteration);
|
const int iteration);
|
||||||
// build response code of get secrets.
|
// build response code of get secrets.
|
||||||
void BuildGetSecretsRsp(std::shared_ptr<ps::server::FBBuilder> get_secrets_resp_builder,
|
void BuildGetSecretsRsp(std::shared_ptr<fl::server::FBBuilder> get_secrets_resp_builder,
|
||||||
const schema::ResponseCode retcode, const int iteration, std::string next_req_time,
|
const schema::ResponseCode retcode, const int iteration, std::string next_req_time,
|
||||||
std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares);
|
std::vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares);
|
||||||
// clear the shared memory.
|
// clear the shared memory.
|
||||||
|
|
|
@ -26,13 +26,13 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
std::vector<float> noise;
|
std::vector<float> noise;
|
||||||
|
|
||||||
cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(ps::server::kCtxClientNoises, &noise);
|
cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
|
||||||
if (noise.size() != cipher_init_->featuremap_) {
|
if (noise.size() != cipher_init_->featuremap_) {
|
||||||
MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
|
MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t data_size = ps::server::LocalMetaStore::GetInstance().value<size_t>(ps::server::kCtxFedAvgTotalDataSize);
|
size_t data_size = fl::server::LocalMetaStore::GetInstance().value<size_t>(fl::server::kCtxFedAvgTotalDataSize);
|
||||||
int sum_size = 0;
|
int sum_size = 0;
|
||||||
for (auto iter = data.begin(); iter != data.end(); ++iter) {
|
for (auto iter = data.begin(); iter != data.end(); ++iter) {
|
||||||
int size_data = iter->second->size / sizeof(float);
|
int size_data = iter->second->size / sizeof(float);
|
||||||
|
|
|
@ -17,13 +17,13 @@
|
||||||
#include "fl/server/collective_ops_impl.h"
|
#include "fl/server/collective_ops_impl.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
|
void CollectiveOpsImpl::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
|
||||||
MS_EXCEPTION_IF_NULL(server_node);
|
MS_EXCEPTION_IF_NULL(server_node);
|
||||||
server_node_ = server_node;
|
server_node_ = server_node;
|
||||||
local_rank_ = server_node_->rank_id();
|
local_rank_ = server_node_->rank_id();
|
||||||
server_num_ = PSContext::instance()->initial_server_num();
|
server_num_ = ps::PSContext::instance()->initial_server_num();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
||||||
// Step 1: Async send data to next rank.
|
// Step 1: Async send data to next rank.
|
||||||
size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
|
size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size;
|
||||||
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
|
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
|
||||||
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
|
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
|
||||||
chunk_sizes[send_chunk_index] * sizeof(T));
|
chunk_sizes[send_chunk_index] * sizeof(T));
|
||||||
// Step 2: Async receive data to next rank and wait until it's done.
|
// Step 2: Async receive data to next rank and wait until it's done.
|
||||||
size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;
|
size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size;
|
||||||
|
@ -76,7 +76,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
||||||
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
|
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
|
||||||
|
|
||||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
||||||
if (!server_node_->CollectiveWait(recv_req_id)) {
|
if (!server_node_->CollectiveWait(recv_req_id)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -104,7 +104,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
||||||
for (size_t i = 0; i < rank_size - 1; i++) {
|
for (size_t i = 0; i < rank_size - 1; i++) {
|
||||||
size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
|
size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size;
|
||||||
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
|
T *send_chunk = output_buff + chunk_offset[send_chunk_index];
|
||||||
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk,
|
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, send_to_rank, send_chunk,
|
||||||
chunk_sizes[send_chunk_index] * sizeof(T));
|
chunk_sizes[send_chunk_index] * sizeof(T));
|
||||||
size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
|
size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size;
|
||||||
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
|
T *recv_chunk = output_buff + chunk_offset[recv_chunk_index];
|
||||||
|
@ -113,7 +113,7 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
||||||
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
|
<< ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i;
|
||||||
|
|
||||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, recv_from_rank, &recv_str);
|
||||||
|
|
||||||
if (!server_node_->CollectiveWait(recv_req_id)) {
|
if (!server_node_->CollectiveWait(recv_req_id)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||||
|
@ -151,7 +151,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
||||||
for (uint32_t i = 1; i < rank_size; i++) {
|
for (uint32_t i = 1; i < rank_size; i++) {
|
||||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
|
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
|
||||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str);
|
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
|
||||||
if (!server_node_->CollectiveWait(recv_req_id)) {
|
if (!server_node_->CollectiveWait(recv_req_id)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -167,7 +167,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
|
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
|
||||||
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
|
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
|
||||||
if (!server_node_->Wait(send_req_id)) {
|
if (!server_node_->Wait(send_req_id)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -180,7 +180,8 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
||||||
if (local_rank_ == 0) {
|
if (local_rank_ == 0) {
|
||||||
for (uint32_t i = 1; i < rank_size; i++) {
|
for (uint32_t i = 1; i < rank_size; i++) {
|
||||||
MS_LOG(DEBUG) << "Broadcast data to process " << i;
|
MS_LOG(DEBUG) << "Broadcast data to process " << i;
|
||||||
auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
|
auto send_req_id =
|
||||||
|
server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
|
||||||
if (!server_node_->Wait(send_req_id)) {
|
if (!server_node_->Wait(send_req_id)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -189,7 +190,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
|
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
|
||||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str);
|
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
|
||||||
if (!server_node_->CollectiveWait(recv_req_id)) {
|
if (!server_node_->CollectiveWait(recv_req_id)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -247,5 +248,5 @@ template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *re
|
||||||
template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
|
template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
|
template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -27,7 +27,7 @@
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// CollectiveOpsImpl is the collective communication API of the server.
|
// CollectiveOpsImpl is the collective communication API of the server.
|
||||||
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
|
// For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also
|
||||||
|
@ -39,7 +39,7 @@ class CollectiveOpsImpl {
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
|
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
bool AllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
@ -48,7 +48,7 @@ class CollectiveOpsImpl {
|
||||||
bool ReInitForScaling();
|
bool ReInitForScaling();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CollectiveOpsImpl() = default;
|
CollectiveOpsImpl() : server_node_(nullptr), local_rank_(0), server_num_(0) {}
|
||||||
~CollectiveOpsImpl() = default;
|
~CollectiveOpsImpl() = default;
|
||||||
CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
|
CollectiveOpsImpl(const CollectiveOpsImpl &) = delete;
|
||||||
CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;
|
CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete;
|
||||||
|
@ -61,7 +61,7 @@ class CollectiveOpsImpl {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count);
|
||||||
|
|
||||||
std::shared_ptr<core::ServerNode> server_node_;
|
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||||
uint32_t local_rank_;
|
uint32_t local_rank_;
|
||||||
uint32_t server_num_;
|
uint32_t server_num_;
|
||||||
|
|
||||||
|
@ -69,6 +69,6 @@ class CollectiveOpsImpl {
|
||||||
std::mutex mtx_;
|
std::mutex mtx_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_COLLECTIVE_OPS_IMPL_H_
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -37,7 +37,7 @@
|
||||||
#include "ps/core/communicator/message_handler.h"
|
#include "ps/core/communicator/message_handler.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// Definitions for the server framework.
|
// Definitions for the server framework.
|
||||||
enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
|
enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER };
|
||||||
|
@ -73,7 +73,7 @@ using TimeOutCb = std::function<void(bool, const std::string &)>;
|
||||||
using StopTimerCb = std::function<void(void)>;
|
using StopTimerCb = std::function<void(void)>;
|
||||||
using FinishIterCb = std::function<void(bool, const std::string &)>;
|
using FinishIterCb = std::function<void(bool, const std::string &)>;
|
||||||
using FinalizeCb = std::function<void(void)>;
|
using FinalizeCb = std::function<void(void)>;
|
||||||
using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>;
|
using MessageCallback = std::function<void(const std::shared_ptr<ps::core::MessageHandler> &)>;
|
||||||
|
|
||||||
// Information about whether server kernel will reuse kernel node memory from the front end.
|
// Information about whether server kernel will reuse kernel node memory from the front end.
|
||||||
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
|
// Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate".
|
||||||
|
@ -237,6 +237,6 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size
|
||||||
// Definitions for Parameter Server.
|
// Definitions for Parameter Server.
|
||||||
|
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_COMMON_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_COMMON_H_
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "fl/server/consistent_hash_ring.h"
|
#include "fl/server/consistent_hash_ring.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
bool ConsistentHashRing::Insert(uint32_t rank) {
|
bool ConsistentHashRing::Insert(uint32_t rank) {
|
||||||
for (uint32_t i = 0; i < virtual_node_num_; i++) {
|
for (uint32_t i = 0; i < virtual_node_num_; i++) {
|
||||||
|
@ -53,5 +53,5 @@ uint32_t ConsistentHashRing::Find(const std::string &key) {
|
||||||
return iterator->second;
|
return iterator->second;
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,15 +14,15 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in
|
// To support distributed storage and make servers easy to scale-out and scale-in for a large load of metadata in
|
||||||
// server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node.
|
// server, we use class ConsistentHashRing to help servers find out which metadata is stored in which server node.
|
||||||
|
@ -59,6 +59,6 @@ class ConsistentHashRing {
|
||||||
std::map<size_t, uint32_t> ring_;
|
std::map<size_t, uint32_t> ring_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_CONSISTENT_HASH_RING_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_CONSISTENT_HASH_RING_H_
|
||||||
|
|
|
@ -20,19 +20,19 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node,
|
void DistributedCountService::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node,
|
||||||
uint32_t counting_server_rank) {
|
uint32_t counting_server_rank) {
|
||||||
server_node_ = server_node;
|
server_node_ = server_node;
|
||||||
MS_EXCEPTION_IF_NULL(server_node_);
|
MS_EXCEPTION_IF_NULL(server_node_);
|
||||||
local_rank_ = server_node_->rank_id();
|
local_rank_ = server_node_->rank_id();
|
||||||
server_num_ = PSContext::instance()->initial_server_num();
|
server_num_ = ps::PSContext::instance()->initial_server_num();
|
||||||
counting_server_rank_ = counting_server_rank;
|
counting_server_rank_ = counting_server_rank;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
|
void DistributedCountService::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
|
||||||
communicator_ = communicator;
|
communicator_ = communicator;
|
||||||
MS_EXCEPTION_IF_NULL(communicator_);
|
MS_EXCEPTION_IF_NULL(communicator_);
|
||||||
communicator_->RegisterMsgCallBack(
|
communicator_->RegisterMsgCallBack(
|
||||||
|
@ -94,7 +94,7 @@ bool DistributedCountService::Count(const std::string &name, const std::string &
|
||||||
report_count_req.set_id(id);
|
report_count_req.set_id(id);
|
||||||
|
|
||||||
std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
|
std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr;
|
||||||
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount,
|
if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, ps::core::TcpUserCommand::kCount,
|
||||||
&report_cnt_rsp_msg)) {
|
&report_cnt_rsp_msg)) {
|
||||||
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
|
MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name;
|
||||||
return false;
|
return false;
|
||||||
|
@ -126,7 +126,7 @@ bool DistributedCountService::CountReachThreshold(const std::string &name) {
|
||||||
|
|
||||||
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
|
std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr;
|
||||||
if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
|
if (!communicator_->SendPbRequest(count_reach_threshold_req, counting_server_rank_,
|
||||||
core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
|
ps::core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) {
|
||||||
MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
|
MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -165,7 +165,7 @@ bool DistributedCountService::ReInitForScaling() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
void DistributedCountService::HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
MS_LOG(ERROR) << "Message is nullptr.";
|
MS_LOG(ERROR) << "Message is nullptr.";
|
||||||
return;
|
return;
|
||||||
|
@ -214,7 +214,8 @@ void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::Mes
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
void DistributedCountService::HandleCountReachThresholdRequest(
|
||||||
|
const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
MS_LOG(ERROR) << "Message is nullptr.";
|
MS_LOG(ERROR) << "Message is nullptr.";
|
||||||
return;
|
return;
|
||||||
|
@ -237,7 +238,7 @@ void DistributedCountService::HandleCountReachThresholdRequest(const std::shared
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
void DistributedCountService::HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
MS_LOG(ERROR) << "Message is nullptr.";
|
MS_LOG(ERROR) << "Message is nullptr.";
|
||||||
return;
|
return;
|
||||||
|
@ -290,7 +291,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name) {
|
||||||
|
|
||||||
// Broadcast to all follower servers.
|
// Broadcast to all follower servers.
|
||||||
for (uint32_t i = 1; i < server_num_; i++) {
|
for (uint32_t i = 1; i < server_num_; i++) {
|
||||||
if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) {
|
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
||||||
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
|
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -308,7 +309,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
|
||||||
|
|
||||||
// Broadcast to all follower servers.
|
// Broadcast to all follower servers.
|
||||||
for (uint32_t i = 1; i < server_num_; i++) {
|
for (uint32_t i = 1; i < server_num_; i++) {
|
||||||
if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) {
|
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
||||||
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
|
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -318,5 +319,5 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -27,7 +27,7 @@
|
||||||
#include "ps/core/communicator/tcp_communicator.h"
|
#include "ps/core/communicator/tcp_communicator.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
constexpr uint32_t kDefaultCountingServerRank = 0;
|
constexpr uint32_t kDefaultCountingServerRank = 0;
|
||||||
constexpr auto kModuleDistributedCountService = "DistributedCountService";
|
constexpr auto kModuleDistributedCountService = "DistributedCountService";
|
||||||
|
@ -54,10 +54,10 @@ class DistributedCountService {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize counter service with the server node because communication is needed.
|
// Initialize counter service with the server node because communication is needed.
|
||||||
void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank);
|
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node, uint32_t counting_server_rank);
|
||||||
|
|
||||||
// Register message callbacks of the counting server to handle messages sent by the other servers.
|
// Register message callbacks of the counting server to handle messages sent by the other servers.
|
||||||
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
|
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
|
||||||
|
|
||||||
// Register counter to the counting server for the name with its threshold count in server cluster dimension and
|
// Register counter to the counting server for the name with its threshold count in server cluster dimension and
|
||||||
// first/last count event callbacks.
|
// first/last count event callbacks.
|
||||||
|
@ -87,15 +87,15 @@ class DistributedCountService {
|
||||||
DistributedCountService &operator=(const DistributedCountService &) = delete;
|
DistributedCountService &operator=(const DistributedCountService &) = delete;
|
||||||
|
|
||||||
// Callback for the reporting count message from other servers. Only counting server will call this method.
|
// Callback for the reporting count message from other servers. Only counting server will call this method.
|
||||||
void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message);
|
void HandleCountRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// Callback for the querying whether threshold count is reached message from other servers. Only counting
|
// Callback for the querying whether threshold count is reached message from other servers. Only counting
|
||||||
// server will call this method.
|
// server will call this method.
|
||||||
void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message);
|
void HandleCountReachThresholdRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// Callback for the first/last event message from the counting server. Only other servers will call this
|
// Callback for the first/last event message from the counting server. Only other servers will call this
|
||||||
// method.
|
// method.
|
||||||
void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message);
|
void HandleCounterEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// Call the callbacks when the first/last count event is triggered.
|
// Call the callbacks when the first/last count event is triggered.
|
||||||
bool TriggerCounterEvent(const std::string &name);
|
bool TriggerCounterEvent(const std::string &name);
|
||||||
|
@ -103,8 +103,8 @@ class DistributedCountService {
|
||||||
bool TriggerLastCountEvent(const std::string &name);
|
bool TriggerLastCountEvent(const std::string &name);
|
||||||
|
|
||||||
// Members for the communication between counting server and other servers.
|
// Members for the communication between counting server and other servers.
|
||||||
std::shared_ptr<core::ServerNode> server_node_;
|
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||||
std::shared_ptr<core::TcpCommunicator> communicator_;
|
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
|
||||||
uint32_t local_rank_;
|
uint32_t local_rank_;
|
||||||
uint32_t server_num_;
|
uint32_t server_num_;
|
||||||
|
|
||||||
|
@ -126,6 +126,6 @@ class DistributedCountService {
|
||||||
std::unordered_map<std::string, std::mutex> mutex_;
|
std::unordered_map<std::string, std::mutex> mutex_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_COUNT_SERVICE_H_
|
||||||
|
|
|
@ -20,18 +20,18 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) {
|
void DistributedMetadataStore::Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node) {
|
||||||
server_node_ = server_node;
|
server_node_ = server_node;
|
||||||
MS_EXCEPTION_IF_NULL(server_node);
|
MS_EXCEPTION_IF_NULL(server_node);
|
||||||
local_rank_ = server_node_->rank_id();
|
local_rank_ = server_node_->rank_id();
|
||||||
server_num_ = PSContext::instance()->initial_server_num();
|
server_num_ = ps::PSContext::instance()->initial_server_num();
|
||||||
InitHashRing();
|
InitHashRing();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
|
void DistributedMetadataStore::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
|
||||||
communicator_ = communicator;
|
communicator_ = communicator;
|
||||||
MS_EXCEPTION_IF_NULL(communicator_);
|
MS_EXCEPTION_IF_NULL(communicator_);
|
||||||
communicator_->RegisterMsgCallBack(
|
communicator_->RegisterMsgCallBack(
|
||||||
|
@ -100,7 +100,7 @@ bool DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBM
|
||||||
metadata_with_name.set_name(name);
|
metadata_with_name.set_name(name);
|
||||||
*metadata_with_name.mutable_metadata() = meta;
|
*metadata_with_name.mutable_metadata() = meta;
|
||||||
std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
|
std::shared_ptr<std::vector<unsigned char>> update_meta_rsp_msg = nullptr;
|
||||||
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata,
|
if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, ps::core::TcpUserCommand::kUpdateMetadata,
|
||||||
&update_meta_rsp_msg)) {
|
&update_meta_rsp_msg)) {
|
||||||
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
|
MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -133,7 +133,7 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
|
||||||
PBMetadata get_metadata_rsp;
|
PBMetadata get_metadata_rsp;
|
||||||
|
|
||||||
std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
|
std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr;
|
||||||
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata,
|
if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, ps::core::TcpUserCommand::kGetMetadata,
|
||||||
&get_meta_rsp_msg)) {
|
&get_meta_rsp_msg)) {
|
||||||
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
|
MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed.";
|
||||||
return get_metadata_rsp;
|
return get_metadata_rsp;
|
||||||
|
@ -174,7 +174,7 @@ void DistributedMetadataStore::InitHashRing() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
MS_LOG(ERROR) << "Message is nullptr.";
|
MS_LOG(ERROR) << "Message is nullptr.";
|
||||||
return;
|
return;
|
||||||
|
@ -196,7 +196,7 @@ void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
MS_LOG(ERROR) << "Message is nullptr.";
|
MS_LOG(ERROR) << "Message is nullptr.";
|
||||||
return;
|
return;
|
||||||
|
@ -267,7 +267,7 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
|
||||||
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
|
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
|
||||||
auto &fl_id = meta.pair_client_shares().fl_id();
|
auto &fl_id = meta.pair_client_shares().fl_id();
|
||||||
auto &client_shares = meta.pair_client_shares().client_shares();
|
auto &client_shares = meta.pair_client_shares().client_shares();
|
||||||
// google::protobuf::Map< std::string, mindspore::ps::core::SharesPb >::const_iterator iter;
|
// google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter;
|
||||||
// Check whether the new item already exists.
|
// Check whether the new item already exists.
|
||||||
bool add_flag = true;
|
bool add_flag = true;
|
||||||
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); iter++) {
|
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); iter++) {
|
||||||
|
@ -299,5 +299,5 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -27,7 +27,7 @@
|
||||||
#include "fl/server/consistent_hash_ring.h"
|
#include "fl/server/consistent_hash_ring.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
constexpr auto kModuleDistributedMetadataStore = "DistributedMetadataStore";
|
constexpr auto kModuleDistributedMetadataStore = "DistributedMetadataStore";
|
||||||
// This class is used for distributed metadata storage using consistent hash. All metadata is distributedly
|
// This class is used for distributed metadata storage using consistent hash. All metadata is distributedly
|
||||||
|
@ -44,10 +44,10 @@ class DistributedMetadataStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize metadata storage with the server node because communication is needed.
|
// Initialize metadata storage with the server node because communication is needed.
|
||||||
void Initialize(const std::shared_ptr<core::ServerNode> &server_node);
|
void Initialize(const std::shared_ptr<ps::core::ServerNode> &server_node);
|
||||||
|
|
||||||
// Register callbacks for the server to handle update/get metadata messages from other servers.
|
// Register callbacks for the server to handle update/get metadata messages from other servers.
|
||||||
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
|
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
|
||||||
|
|
||||||
// Register metadata for the name with the initial value. This method should be only called once for each name.
|
// Register metadata for the name with the initial value. This method should be only called once for each name.
|
||||||
void RegisterMetadata(const std::string &name, const PBMetadata &meta);
|
void RegisterMetadata(const std::string &name, const PBMetadata &meta);
|
||||||
|
@ -65,7 +65,13 @@ class DistributedMetadataStore {
|
||||||
bool ReInitForScaling();
|
bool ReInitForScaling();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DistributedMetadataStore() = default;
|
DistributedMetadataStore()
|
||||||
|
: server_node_(nullptr),
|
||||||
|
communicator_(nullptr),
|
||||||
|
local_rank_(0),
|
||||||
|
server_num_(0),
|
||||||
|
router_(nullptr),
|
||||||
|
metadata_({}) {}
|
||||||
~DistributedMetadataStore() = default;
|
~DistributedMetadataStore() = default;
|
||||||
DistributedMetadataStore(const DistributedMetadataStore &) = delete;
|
DistributedMetadataStore(const DistributedMetadataStore &) = delete;
|
||||||
DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete;
|
DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete;
|
||||||
|
@ -74,17 +80,17 @@ class DistributedMetadataStore {
|
||||||
void InitHashRing();
|
void InitHashRing();
|
||||||
|
|
||||||
// Callback for updating metadata request sent to the server.
|
// Callback for updating metadata request sent to the server.
|
||||||
void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
|
void HandleUpdateMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// Callback for getting metadata request sent to the server.
|
// Callback for getting metadata request sent to the server.
|
||||||
void HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message);
|
void HandleGetMetadataRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// Do updating metadata in the server where the metadata for the name is stored.
|
// Do updating metadata in the server where the metadata for the name is stored.
|
||||||
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
|
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||||
|
|
||||||
// Members for the communication between servers.
|
// Members for the communication between servers.
|
||||||
std::shared_ptr<core::ServerNode> server_node_;
|
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||||
std::shared_ptr<core::TcpCommunicator> communicator_;
|
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
|
||||||
uint32_t local_rank_;
|
uint32_t local_rank_;
|
||||||
uint32_t server_num_;
|
uint32_t server_num_;
|
||||||
|
|
||||||
|
@ -100,6 +106,6 @@ class DistributedMetadataStore {
|
||||||
std::unordered_map<std::string, std::mutex> mutex_;
|
std::unordered_map<std::string, std::mutex> mutex_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_DISTRIBUTED_META_STORE_H_
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
|
void Executor::Initialize(const FuncGraphPtr &func_graph, size_t aggregation_count) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
@ -320,5 +320,5 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
@ -31,7 +31,7 @@
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles
|
// Executor is the entrance for server to handle aggregation, optimizing, model querying, etc. It handles
|
||||||
// logics relevant to kernel launching.
|
// logics relevant to kernel launching.
|
||||||
|
@ -94,7 +94,7 @@ class Executor {
|
||||||
bool Unmask();
|
bool Unmask();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Executor() {}
|
Executor() : initialized_(false), aggregation_count_(0), param_names_({}), param_aggrs_({}) {}
|
||||||
~Executor() = default;
|
~Executor() = default;
|
||||||
Executor(const Executor &) = delete;
|
Executor(const Executor &) = delete;
|
||||||
Executor &operator=(const Executor &) = delete;
|
Executor &operator=(const Executor &) = delete;
|
||||||
|
@ -126,6 +126,6 @@ class Executor {
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_EXECUTOR_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_EXECUTOR_H_
|
||||||
|
|
|
@ -23,10 +23,10 @@
|
||||||
#include "fl/server/server.h"
|
#include "fl/server/server.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
class Server;
|
class Server;
|
||||||
void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
|
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
|
||||||
MS_EXCEPTION_IF_NULL(communicator);
|
MS_EXCEPTION_IF_NULL(communicator);
|
||||||
communicator_ = communicator;
|
communicator_ = communicator;
|
||||||
communicator_->RegisterMsgCallBack("syncIteration",
|
communicator_->RegisterMsgCallBack("syncIteration",
|
||||||
|
@ -42,12 +42,12 @@ void Iteration::RegisterMessageCallback(const std::shared_ptr<core::TcpCommunica
|
||||||
std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1));
|
std::bind(&Iteration::HandleEndLastIterRequest, this, std::placeholders::_1));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node) {
|
void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) {
|
||||||
MS_EXCEPTION_IF_NULL(server_node);
|
MS_EXCEPTION_IF_NULL(server_node);
|
||||||
server_node_ = server_node;
|
server_node_ = server_node;
|
||||||
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning),
|
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
|
||||||
std::bind(&Iteration::HandleIterationRunningEvent, this));
|
std::bind(&Iteration::HandleIterationRunningEvent, this));
|
||||||
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted),
|
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
|
||||||
std::bind(&Iteration::HandleIterationCompletedEvent, this));
|
std::bind(&Iteration::HandleIterationCompletedEvent, this));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,7 +56,7 @@ void Iteration::AddRound(const std::shared_ptr<Round> &round) {
|
||||||
rounds_.push_back(round);
|
rounds_.push_back(round);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
|
void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
|
||||||
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
|
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) {
|
||||||
if (communicators.empty()) {
|
if (communicators.empty()) {
|
||||||
MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";
|
MS_LOG(EXCEPTION) << "Communicators for rounds is empty.";
|
||||||
|
@ -64,7 +64,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorB
|
||||||
}
|
}
|
||||||
|
|
||||||
std::for_each(communicators.begin(), communicators.end(),
|
std::for_each(communicators.begin(), communicators.end(),
|
||||||
[&](const std::shared_ptr<core::CommunicatorBase> &communicator) {
|
[&](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) {
|
||||||
for (auto &round : rounds_) {
|
for (auto &round : rounds_) {
|
||||||
if (round == nullptr) {
|
if (round == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -120,7 +120,7 @@ void Iteration::SetIterationRunning() {
|
||||||
}
|
}
|
||||||
if (server_node_->rank_id() == kLeaderServerRank) {
|
if (server_node_->rank_id() == kLeaderServerRank) {
|
||||||
// This event helps worker/server to be consistent in iteration state.
|
// This event helps worker/server to be consistent in iteration state.
|
||||||
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationRunning));
|
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning));
|
||||||
}
|
}
|
||||||
iteration_state_ = IterationState::kRunning;
|
iteration_state_ = IterationState::kRunning;
|
||||||
}
|
}
|
||||||
|
@ -133,7 +133,7 @@ void Iteration::SetIterationCompleted() {
|
||||||
}
|
}
|
||||||
if (server_node_->rank_id() == kLeaderServerRank) {
|
if (server_node_->rank_id() == kLeaderServerRank) {
|
||||||
// This event helps worker/server to be consistent in iteration state.
|
// This event helps worker/server to be consistent in iteration state.
|
||||||
server_node_->BroadcastEvent(static_cast<uint32_t>(CustomEvent::kIterationCompleted));
|
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted));
|
||||||
}
|
}
|
||||||
iteration_state_ = IterationState::kCompleted;
|
iteration_state_ = IterationState::kCompleted;
|
||||||
}
|
}
|
||||||
|
@ -171,7 +171,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
|
||||||
sync_iter_req.set_rank(rank);
|
sync_iter_req.set_rank(rank);
|
||||||
|
|
||||||
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
|
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
|
||||||
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, core::TcpUserCommand::kSyncIteration,
|
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration,
|
||||||
&sync_iter_rsp_msg)) {
|
&sync_iter_rsp_msg)) {
|
||||||
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
|
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -189,7 +189,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
void Iteration::HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
MS_LOG(ERROR) << "Message is nullptr.";
|
MS_LOG(ERROR) << "Message is nullptr.";
|
||||||
return;
|
return;
|
||||||
|
@ -224,14 +224,14 @@ bool Iteration::NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const s
|
||||||
notify_leader_to_next_iter_req.set_iter_num(iteration_num_);
|
notify_leader_to_next_iter_req.set_iter_num(iteration_num_);
|
||||||
notify_leader_to_next_iter_req.set_reason(reason);
|
notify_leader_to_next_iter_req.set_reason(reason);
|
||||||
if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank,
|
if (!communicator_->SendPbRequest(notify_leader_to_next_iter_req, kLeaderServerRank,
|
||||||
core::TcpUserCommand::kNotifyLeaderToNextIter)) {
|
ps::core::TcpUserCommand::kNotifyLeaderToNextIter)) {
|
||||||
MS_LOG(WARNING) << "Sending notify leader server to proceed next iteration request to leader server 0 failed.";
|
MS_LOG(WARNING) << "Sending notify leader server to proceed next iteration request to leader server 0 failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
void Iteration::HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -278,7 +278,7 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
|
||||||
|
|
||||||
std::vector<uint32_t> offline_servers = {};
|
std::vector<uint32_t> offline_servers = {};
|
||||||
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
|
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
|
||||||
if (!communicator_->SendPbRequest(prepare_next_iter_req, i, core::TcpUserCommand::kPrepareForNextIter)) {
|
if (!communicator_->SendPbRequest(prepare_next_iter_req, i, ps::core::TcpUserCommand::kPrepareForNextIter)) {
|
||||||
MS_LOG(WARNING) << "Sending prepare for next iteration request to server " << i << " failed. Retry later.";
|
MS_LOG(WARNING) << "Sending prepare for next iteration request to server " << i << " failed. Retry later.";
|
||||||
offline_servers.push_back(i);
|
offline_servers.push_back(i);
|
||||||
continue;
|
continue;
|
||||||
|
@ -289,17 +289,18 @@ bool Iteration::BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, cons
|
||||||
std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
|
std::for_each(offline_servers.begin(), offline_servers.end(), [&](uint32_t rank) {
|
||||||
// Should avoid endless loop if the server communicator is stopped.
|
// Should avoid endless loop if the server communicator is stopped.
|
||||||
while (communicator_->running() &&
|
while (communicator_->running() &&
|
||||||
!communicator_->SendPbRequest(prepare_next_iter_req, rank, core::TcpUserCommand::kPrepareForNextIter)) {
|
!communicator_->SendPbRequest(prepare_next_iter_req, rank, ps::core::TcpUserCommand::kPrepareForNextIter)) {
|
||||||
MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank
|
MS_LOG(WARNING) << "Retry sending prepare for next iteration request to server " << rank
|
||||||
<< " failed. The server has not recovered yet.";
|
<< " failed. The server has not recovered yet.";
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));
|
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationForPrepareForNextIter));
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success.";
|
MS_LOG(INFO) << "Offline server " << rank << " preparing for next iteration success.";
|
||||||
});
|
});
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -329,7 +330,7 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
|
||||||
proceed_to_next_iter_req.set_last_iter_num(iteration_num_);
|
proceed_to_next_iter_req.set_last_iter_num(iteration_num_);
|
||||||
proceed_to_next_iter_req.set_reason(reason);
|
proceed_to_next_iter_req.set_reason(reason);
|
||||||
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
|
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
|
||||||
if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, core::TcpUserCommand::kProceedToNextIter)) {
|
if (!communicator_->SendPbRequest(proceed_to_next_iter_req, i, ps::core::TcpUserCommand::kProceedToNextIter)) {
|
||||||
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
|
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -339,7 +340,7 @@ bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const st
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
void Iteration::HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -388,7 +389,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
|
||||||
EndLastIterRequest end_last_iter_req;
|
EndLastIterRequest end_last_iter_req;
|
||||||
end_last_iter_req.set_last_iter_num(last_iter_num);
|
end_last_iter_req.set_last_iter_num(last_iter_num);
|
||||||
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
|
for (uint32_t i = 1; i < IntToUint(server_node_->server_num()); i++) {
|
||||||
if (!communicator_->SendPbRequest(end_last_iter_req, i, core::TcpUserCommand::kEndLastIter)) {
|
if (!communicator_->SendPbRequest(end_last_iter_req, i, ps::core::TcpUserCommand::kEndLastIter)) {
|
||||||
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
|
MS_LOG(WARNING) << "Sending proceed to next iteration request to server " << i << " failed.";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -398,7 +399,7 @@ bool Iteration::BroadcastEndLastIterRequest(uint64_t last_iter_num) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message) {
|
void Iteration::HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -429,9 +430,9 @@ void Iteration::EndLastIter() {
|
||||||
MS_LOG(INFO) << "End the last iteration " << iteration_num_;
|
MS_LOG(INFO) << "End the last iteration " << iteration_num_;
|
||||||
iteration_num_++;
|
iteration_num_++;
|
||||||
// After the job is done, reset the iteration to the initial number and reset ModelStore.
|
// After the job is done, reset the iteration to the initial number and reset ModelStore.
|
||||||
if (iteration_num_ > PSContext::instance()->fl_iteration_num()) {
|
if (iteration_num_ > ps::PSContext::instance()->fl_iteration_num()) {
|
||||||
MS_LOG(INFO) << "Iteration loop " << iteration_loop_count_
|
MS_LOG(INFO) << "Iteration loop " << iteration_loop_count_
|
||||||
<< " is completed. Iteration number: " << PSContext::instance()->fl_iteration_num();
|
<< " is completed. Iteration number: " << ps::PSContext::instance()->fl_iteration_num();
|
||||||
iteration_num_ = 1;
|
iteration_num_ = 1;
|
||||||
iteration_loop_count_++;
|
iteration_loop_count_++;
|
||||||
ModelStore::GetInstance().Reset();
|
ModelStore::GetInstance().Reset();
|
||||||
|
@ -444,5 +445,5 @@ void Iteration::EndLastIter() {
|
||||||
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
|
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
#include "fl/server/local_meta_store.h"
|
#include "fl/server/local_meta_store.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
enum class IterationState {
|
enum class IterationState {
|
||||||
// This iteration is still in process.
|
// This iteration is still in process.
|
||||||
|
@ -48,16 +48,16 @@ class Iteration {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register callbacks for other servers to synchronize iteration information from leader server.
|
// Register callbacks for other servers to synchronize iteration information from leader server.
|
||||||
void RegisterMessageCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
|
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
|
||||||
|
|
||||||
// Register event callbacks for iteration state synchronization.
|
// Register event callbacks for iteration state synchronization.
|
||||||
void RegisterEventCallback(const std::shared_ptr<core::ServerNode> &server_node);
|
void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node);
|
||||||
|
|
||||||
// Add a round for the iteration. This method will be called multiple times for each round.
|
// Add a round for the iteration. This method will be called multiple times for each round.
|
||||||
void AddRound(const std::shared_ptr<Round> &round);
|
void AddRound(const std::shared_ptr<Round> &round);
|
||||||
|
|
||||||
// Initialize all the rounds in the iteration.
|
// Initialize all the rounds in the iteration.
|
||||||
void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators,
|
void InitRounds(const std::vector<std::shared_ptr<ps::core::CommunicatorBase>> &communicators,
|
||||||
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
|
const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb);
|
||||||
|
|
||||||
// This method will control servers to proceed to next iteration.
|
// This method will control servers to proceed to next iteration.
|
||||||
|
@ -104,7 +104,7 @@ class Iteration {
|
||||||
|
|
||||||
// Synchronize iteration form the leader server(Rank 0).
|
// Synchronize iteration form the leader server(Rank 0).
|
||||||
bool SyncIteration(uint32_t rank);
|
bool SyncIteration(uint32_t rank);
|
||||||
void HandleSyncIterationRequest(const std::shared_ptr<core::MessageHandler> &message);
|
void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// The request for moving to next iteration is not reentrant.
|
// The request for moving to next iteration is not reentrant.
|
||||||
bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num);
|
bool IsMoveToNextIterRequestReentrant(uint64_t iteration_num);
|
||||||
|
@ -112,28 +112,28 @@ class Iteration {
|
||||||
// The methods for moving to next iteration for all the servers.
|
// The methods for moving to next iteration for all the servers.
|
||||||
// Step 1: follower servers notify leader server that they need to move to next iteration.
|
// Step 1: follower servers notify leader server that they need to move to next iteration.
|
||||||
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
|
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
|
||||||
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
|
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// Step 2: leader server broadcast to all follower servers to prepare for next iteration and switch to safemode.
|
// Step 2: leader server broadcast to all follower servers to prepare for next iteration and switch to safemode.
|
||||||
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
|
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
|
||||||
void HandlePrepareForNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
|
void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
// The server prepare for the next iteration. This method will switch the server to safemode.
|
// The server prepare for the next iteration. This method will switch the server to safemode.
|
||||||
void PrepareForNextIter();
|
void PrepareForNextIter();
|
||||||
|
|
||||||
// Step 3: leader server broadcast to all follower servers to move to next iteration.
|
// Step 3: leader server broadcast to all follower servers to move to next iteration.
|
||||||
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
|
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
|
||||||
void HandleMoveToNextIterRequest(const std::shared_ptr<core::MessageHandler> &message);
|
void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
// Move to next iteration. Store last iterations model and reset all the rounds.
|
// Move to next iteration. Store last iterations model and reset all the rounds.
|
||||||
void Next(bool is_iteration_valid, const std::string &reason);
|
void Next(bool is_iteration_valid, const std::string &reason);
|
||||||
|
|
||||||
// Step 4: leader server broadcasts to all follower servers to end last iteration and cancel the safemode.
|
// Step 4: leader server broadcasts to all follower servers to end last iteration and cancel the safemode.
|
||||||
bool BroadcastEndLastIterRequest(uint64_t iteration_num);
|
bool BroadcastEndLastIterRequest(uint64_t iteration_num);
|
||||||
void HandleEndLastIterRequest(const std::shared_ptr<core::MessageHandler> &message);
|
void HandleEndLastIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
// The server end the last iteration. This method will increase the iteration number and cancel the safemode.
|
// The server end the last iteration. This method will increase the iteration number and cancel the safemode.
|
||||||
void EndLastIter();
|
void EndLastIter();
|
||||||
|
|
||||||
std::shared_ptr<core::ServerNode> server_node_;
|
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||||
std::shared_ptr<core::TcpCommunicator> communicator_;
|
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
|
||||||
|
|
||||||
// All the rounds in the server.
|
// All the rounds in the server.
|
||||||
std::vector<std::shared_ptr<Round>> rounds_;
|
std::vector<std::shared_ptr<Round>> rounds_;
|
||||||
|
@ -155,6 +155,6 @@ class Iteration {
|
||||||
std::mutex pinned_mtx_;
|
std::mutex pinned_mtx_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_H_
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "fl/server/iteration_timer.h"
|
#include "fl/server/iteration_timer.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
void IterationTimer::Start(const std::chrono::milliseconds &duration) {
|
void IterationTimer::Start(const std::chrono::milliseconds &duration) {
|
||||||
if (running_.load()) {
|
if (running_.load()) {
|
||||||
|
@ -52,5 +52,5 @@ bool IterationTimer::IsTimeOut(const std::chrono::milliseconds ×tamp) const
|
||||||
|
|
||||||
bool IterationTimer::IsRunning() const { return running_; }
|
bool IterationTimer::IsRunning() const { return running_; }
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
@ -24,7 +24,7 @@
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// IterationTimer controls the time window for the purpose of eliminating trailing time of each iteration.
|
// IterationTimer controls the time window for the purpose of eliminating trailing time of each iteration.
|
||||||
class IterationTimer {
|
class IterationTimer {
|
||||||
|
@ -59,6 +59,6 @@ class IterationTimer {
|
||||||
TimeOutCb timeout_callback_;
|
TimeOutCb timeout_callback_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_TIMER_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_ITERATION_TIMER_H_
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
#include "fl/server/kernel/params_info.h"
|
#include "fl/server/kernel/params_info.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// AggregationKernel is the kernel for weight, grad or other kinds of parameters' aggregation.
|
// AggregationKernel is the kernel for weight, grad or other kinds of parameters' aggregation.
|
||||||
|
@ -99,6 +99,6 @@ class AggregationKernel : public CPUKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_H_
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
bool AggregationKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) {
|
bool AggregationKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) {
|
||||||
|
@ -67,5 +67,5 @@ bool AggregationKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNod
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -24,7 +24,7 @@
|
||||||
#include "fl/server/kernel/aggregation_kernel.h"
|
#include "fl/server/kernel/aggregation_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>;
|
using AggregationKernelCreator = std::function<std::shared_ptr<AggregationKernel>()>;
|
||||||
|
@ -51,6 +51,7 @@ class AggregationKernelRegister {
|
||||||
AggregationKernelCreator &&creator) {
|
AggregationKernelCreator &&creator) {
|
||||||
AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
|
AggregationKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
|
||||||
}
|
}
|
||||||
|
~AggregationKernelRegister() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Register aggregation kernel with one template type T.
|
// Register aggregation kernel with one template type T.
|
||||||
|
@ -66,6 +67,6 @@ class AggregationKernelRegister {
|
||||||
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); });
|
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T, S>>(); });
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_AGGREGATION_KERNEL_FACTORY_H_
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "fl/server/kernel/apply_momentum_kernel.h"
|
#include "fl/server/kernel/apply_momentum_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
REG_OPTIMIZER_KERNEL(ApplyMomentum,
|
REG_OPTIMIZER_KERNEL(ApplyMomentum,
|
||||||
|
@ -30,5 +30,5 @@ REG_OPTIMIZER_KERNEL(ApplyMomentum,
|
||||||
ApplyMomentumKernel, float)
|
ApplyMomentumKernel, float)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -25,7 +25,7 @@
|
||||||
#include "fl/server/kernel/optimizer_kernel_factory.h"
|
#include "fl/server/kernel/optimizer_kernel_factory.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
using mindspore::kernel::ApplyMomentumCPUKernel;
|
using mindspore::kernel::ApplyMomentumCPUKernel;
|
||||||
|
@ -57,6 +57,6 @@ class ApplyMomentumKernel : public ApplyMomentumCPUKernel, public OptimizerKerne
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_APPLY_MOMENTUM_KERNEL_H_
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "fl/server/kernel/dense_grad_accum_kernel.h"
|
#include "fl/server/kernel/dense_grad_accum_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
REG_AGGREGATION_KERNEL(
|
REG_AGGREGATION_KERNEL(
|
||||||
|
@ -26,5 +26,5 @@ REG_AGGREGATION_KERNEL(
|
||||||
DenseGradAccumKernel, float)
|
DenseGradAccumKernel, float)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
#include "fl/server/kernel/aggregation_kernel_factory.h"
|
#include "fl/server/kernel/aggregation_kernel_factory.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -90,6 +90,6 @@ class DenseGradAccumKernel : public AggregationKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_DENSE_GRAD_ACCUM_KERNEL_H_
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "fl/server/kernel/fed_avg_kernel.h"
|
#include "fl/server/kernel/fed_avg_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
REG_AGGREGATION_KERNEL_TWO(FedAvg,
|
REG_AGGREGATION_KERNEL_TWO(FedAvg,
|
||||||
|
@ -29,5 +29,5 @@ REG_AGGREGATION_KERNEL_TWO(FedAvg,
|
||||||
FedAvgKernel, float, size_t)
|
FedAvgKernel, float, size_t)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -31,7 +31,7 @@
|
||||||
#include "fl/server/kernel/aggregation_kernel_factory.h"
|
#include "fl/server/kernel/aggregation_kernel_factory.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// The implementation for the federated average. We do weighted average for the weights. The uploaded weights from
|
// The implementation for the federated average. We do weighted average for the weights. The uploaded weights from
|
||||||
|
@ -42,7 +42,13 @@ namespace kernel {
|
||||||
template <typename T, typename S>
|
template <typename T, typename S>
|
||||||
class FedAvgKernel : public AggregationKernel {
|
class FedAvgKernel : public AggregationKernel {
|
||||||
public:
|
public:
|
||||||
FedAvgKernel() : participated_(false) {}
|
FedAvgKernel()
|
||||||
|
: cnode_weight_idx_(0),
|
||||||
|
weight_addr_(nullptr),
|
||||||
|
data_size_addr_(nullptr),
|
||||||
|
new_weight_addr_(nullptr),
|
||||||
|
new_data_size_addr_(nullptr),
|
||||||
|
participated_(false) {}
|
||||||
~FedAvgKernel() override = default;
|
~FedAvgKernel() override = default;
|
||||||
|
|
||||||
void InitKernel(const CNodePtr &kernel_node) override {
|
void InitKernel(const CNodePtr &kernel_node) override {
|
||||||
|
@ -68,13 +74,13 @@ class FedAvgKernel : public AggregationKernel {
|
||||||
AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
|
AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, cnode_weight_idx_), 0).first;
|
||||||
MS_EXCEPTION_IF_NULL(weight_node);
|
MS_EXCEPTION_IF_NULL(weight_node);
|
||||||
name_ = cnode_name + "." + weight_node->fullname_with_scope();
|
name_ = cnode_name + "." + weight_node->fullname_with_scope();
|
||||||
first_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) {
|
first_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
|
||||||
std::unique_lock<std::mutex> lock(weight_mutex_);
|
std::unique_lock<std::mutex> lock(weight_mutex_);
|
||||||
if (!participated_) {
|
if (!participated_) {
|
||||||
ClearWeightAndDataSize();
|
ClearWeightAndDataSize();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
last_cnt_handler_ = [&](std::shared_ptr<core::MessageHandler>) {
|
last_cnt_handler_ = [&](std::shared_ptr<ps::core::MessageHandler>) {
|
||||||
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
|
T *weight_addr = reinterpret_cast<T *>(weight_addr_->addr);
|
||||||
size_t weight_size = weight_addr_->size;
|
size_t weight_size = weight_addr_->size;
|
||||||
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
|
S *data_size_addr = reinterpret_cast<S *>(data_size_addr_->addr);
|
||||||
|
@ -193,7 +199,7 @@ class FedAvgKernel : public AggregationKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_FED_AVG_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_FED_AVG_KERNEL_H_
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
#include "fl/server/kernel/params_info.h"
|
#include "fl/server/kernel/params_info.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// KernelFactory is used to select and build kernels in server. It's the base class of OptimizerKernelFactory
|
// KernelFactory is used to select and build kernels in server. It's the base class of OptimizerKernelFactory
|
||||||
|
@ -87,6 +87,6 @@ class KernelFactory {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_KERNEL_FACTORY_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_KERNEL_FACTORY_H_
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -28,7 +28,7 @@
|
||||||
#include "fl/server/kernel/params_info.h"
|
#include "fl/server/kernel/params_info.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
using mindspore::kernel::IsSameShape;
|
using mindspore::kernel::IsSameShape;
|
||||||
|
@ -92,6 +92,6 @@ class OptimizerKernel : public CPUKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_H_
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
bool OptimizerKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) {
|
bool OptimizerKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) {
|
||||||
|
@ -66,5 +66,5 @@ bool OptimizerKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodeP
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -24,7 +24,7 @@
|
||||||
#include "fl/server/kernel/optimizer_kernel.h"
|
#include "fl/server/kernel/optimizer_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>;
|
using OptimizerKernelCreator = std::function<std::shared_ptr<OptimizerKernel>()>;
|
||||||
|
@ -50,6 +50,7 @@ class OptimizerKernelRegister {
|
||||||
OptimizerKernelRegister(const std::string &name, const ParamsInfo ¶ms_info, OptimizerKernelCreator &&creator) {
|
OptimizerKernelRegister(const std::string &name, const ParamsInfo ¶ms_info, OptimizerKernelCreator &&creator) {
|
||||||
OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
|
OptimizerKernelFactory::GetInstance().Register(name, params_info, std::move(creator));
|
||||||
}
|
}
|
||||||
|
~OptimizerKernelRegister() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Register optimizer kernel with one template type T.
|
// Register optimizer kernel with one template type T.
|
||||||
|
@ -59,6 +60,6 @@ class OptimizerKernelRegister {
|
||||||
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
|
#NAME, PARAMS_INFO, []() { return std::make_shared<CLASS<T>>(); });
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_OPTIMIZER_KERNEL_FACTORY_H_
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
ParamsInfo &ParamsInfo::AddInputNameType(const std::string &name, TypeId type) {
|
ParamsInfo &ParamsInfo::AddInputNameType(const std::string &name, TypeId type) {
|
||||||
|
@ -64,5 +64,5 @@ const std::vector<std::string> &ParamsInfo::workspace_names() const { return wor
|
||||||
const std::vector<std::string> &ParamsInfo::outputs_names() const { return outputs_names_; }
|
const std::vector<std::string> &ParamsInfo::outputs_names() const { return outputs_names_; }
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -23,7 +23,7 @@
|
||||||
#include "ir/dtype/type_id.h"
|
#include "ir/dtype/type_id.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// ParamsInfo is used for server computation kernel's register, e.g, ApplyMomentumKernel, FedAvgKernel, etc.
|
// ParamsInfo is used for server computation kernel's register, e.g, ApplyMomentumKernel, FedAvgKernel, etc.
|
||||||
|
@ -65,6 +65,6 @@ class ParamsInfo {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PARAMS_INFO_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PARAMS_INFO_H_
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
#include "schema/cipher_generated.h"
|
#include "schema/cipher_generated.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void ClientListKernel::InitKernel(size_t) {
|
void ClientListKernel::InitKernel(size_t) {
|
||||||
|
@ -150,7 +150,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
|
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
|
||||||
return true;
|
return true;
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
|
|
||||||
bool ClientListKernel::Reset() {
|
bool ClientListKernel::Reset() {
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
|
@ -196,5 +196,5 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli
|
||||||
REG_ROUND_KERNEL(getClientList, ClientListKernel)
|
REG_ROUND_KERNEL(getClientList, ClientListKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class ClientListKernel : public RoundKernel {
|
class ClientListKernel : public RoundKernel {
|
||||||
|
@ -50,6 +50,6 @@ class ClientListKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void ExchangeKeysKernel::InitKernel(size_t) {
|
void ExchangeKeysKernel::InitKernel(size_t) {
|
||||||
|
@ -100,5 +100,5 @@ bool ExchangeKeysKernel::Reset() {
|
||||||
REG_ROUND_KERNEL(exchangeKeys, ExchangeKeysKernel)
|
REG_ROUND_KERNEL(exchangeKeys, ExchangeKeysKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
@ -25,7 +25,7 @@
|
||||||
#include "fl/armour/cipher/cipher_keys.h"
|
#include "fl/armour/cipher/cipher_keys.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class ExchangeKeysKernel : public RoundKernel {
|
class ExchangeKeysKernel : public RoundKernel {
|
||||||
|
@ -44,7 +44,7 @@ class ExchangeKeysKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void GetKeysKernel::InitKernel(size_t) {
|
void GetKeysKernel::InitKernel(size_t) {
|
||||||
|
@ -99,5 +99,5 @@ bool GetKeysKernel::Reset() {
|
||||||
REG_ROUND_KERNEL(getKeys, GetKeysKernel)
|
REG_ROUND_KERNEL(getKeys, GetKeysKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
@ -25,7 +25,7 @@
|
||||||
#include "fl/armour/cipher/cipher_keys.h"
|
#include "fl/armour/cipher/cipher_keys.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class GetKeysKernel : public RoundKernel {
|
class GetKeysKernel : public RoundKernel {
|
||||||
|
@ -44,7 +44,7 @@ class GetKeysKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
#include "fl/server/model_store.h"
|
#include "fl/server/model_store.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void GetModelKernel::InitKernel(size_t) {
|
void GetModelKernel::InitKernel(size_t) {
|
||||||
|
@ -133,5 +133,5 @@ void GetModelKernel::BuildGetModelRsp(const std::shared_ptr<FBBuilder> &fbb, con
|
||||||
REG_ROUND_KERNEL(getModel, GetModelKernel)
|
REG_ROUND_KERNEL(getModel, GetModelKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_MODEL_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_MODEL_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_MODEL_KERNEL_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -27,13 +27,13 @@
|
||||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
constexpr uint32_t kPrintGetModelForEveryRetryTime = 50;
|
constexpr uint32_t kPrintGetModelForEveryRetryTime = 50;
|
||||||
class GetModelKernel : public RoundKernel {
|
class GetModelKernel : public RoundKernel {
|
||||||
public:
|
public:
|
||||||
GetModelKernel() = default;
|
GetModelKernel() : executor_(nullptr), iteration_time_window_(0), retry_count_(0) {}
|
||||||
~GetModelKernel() override = default;
|
~GetModelKernel() override = default;
|
||||||
|
|
||||||
void InitKernel(size_t) override;
|
void InitKernel(size_t) override;
|
||||||
|
@ -58,6 +58,6 @@ class GetModelKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "fl/armour/cipher/cipher_shares.h"
|
#include "fl/armour/cipher/cipher_shares.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void GetSecretsKernel::InitKernel(size_t) {
|
void GetSecretsKernel::InitKernel(size_t) {
|
||||||
|
@ -102,5 +102,5 @@ bool GetSecretsKernel::Reset() {
|
||||||
REG_ROUND_KERNEL(getSecrets, GetSecretsKernel)
|
REG_ROUND_KERNEL(getSecrets, GetSecretsKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
@ -25,7 +25,7 @@
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class GetSecretsKernel : public RoundKernel {
|
class GetSecretsKernel : public RoundKernel {
|
||||||
|
@ -44,7 +44,7 @@ class GetSecretsKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
||||||
|
|
|
@ -22,7 +22,7 @@
|
||||||
#include "fl/server/model_store.h"
|
#include "fl/server/model_store.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void PullWeightKernel::InitKernel(size_t) {
|
void PullWeightKernel::InitKernel(size_t) {
|
||||||
|
@ -137,5 +137,5 @@ void PullWeightKernel::BuildPullWeightRsp(std::shared_ptr<FBBuilder> fbb, const
|
||||||
REG_ROUND_KERNEL(pullWeight, PullWeightKernel)
|
REG_ROUND_KERNEL(pullWeight, PullWeightKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -27,7 +27,7 @@
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
constexpr uint32_t kPrintPullWeightForEveryRetryTime = 500;
|
constexpr uint32_t kPrintPullWeightForEveryRetryTime = 500;
|
||||||
|
@ -53,6 +53,6 @@ class PullWeightKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PULL_WEIGHT_KERNEL_H_
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "fl/server/kernel/round/push_weight_kernel.h"
|
#include "fl/server/kernel/round/push_weight_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void PushWeightKernel::InitKernel(size_t) {
|
void PushWeightKernel::InitKernel(size_t) {
|
||||||
|
@ -60,8 +60,8 @@ bool PushWeightKernel::Reset() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) {
|
void PushWeightKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
|
||||||
if (PSContext::instance()->resetter_round() == ResetterRound::kPushWeight) {
|
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kPushWeight) {
|
||||||
FinishIteration();
|
FinishIteration();
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
|
@ -136,5 +136,5 @@ void PushWeightKernel::BuildPushWeightRsp(std::shared_ptr<FBBuilder> fbb, const
|
||||||
REG_ROUND_KERNEL(pushWeight, PushWeightKernel)
|
REG_ROUND_KERNEL(pushWeight, PushWeightKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -27,7 +27,7 @@
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class PushWeightKernel : public RoundKernel {
|
class PushWeightKernel : public RoundKernel {
|
||||||
|
@ -39,7 +39,7 @@ class PushWeightKernel : public RoundKernel {
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs);
|
const std::vector<AddressPtr> &outputs);
|
||||||
bool Reset() override;
|
bool Reset() override;
|
||||||
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
|
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
|
bool PushWeight(std::shared_ptr<FBBuilder> fbb, const schema::RequestPushWeight *push_weight_req);
|
||||||
|
@ -52,6 +52,6 @@ class PushWeightKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_WEIGHT_KERNEL_H_
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
|
void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
|
||||||
|
@ -34,17 +34,17 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
|
||||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
|
auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
|
||||||
MS_LOG(INFO) << "start FinishIteration";
|
MS_LOG(INFO) << "start FinishIteration";
|
||||||
FinishIteration();
|
FinishIteration();
|
||||||
MS_LOG(INFO) << "end FinishIteration";
|
MS_LOG(INFO) << "end FinishIteration";
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) { return; };
|
auto first_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) { return; };
|
||||||
name_unmask_ = "UnMaskKernel";
|
name_unmask_ = "UnMaskKernel";
|
||||||
MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : "
|
MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : "
|
||||||
<< LocalMetaStore::GetInstance().curr_iter_num();
|
<< LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, PSContext::instance()->initial_server_num(),
|
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, ps::PSContext::instance()->initial_server_num(),
|
||||||
{first_cnt_handler, last_cnt_handler});
|
{first_cnt_handler, last_cnt_handler});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,9 +134,9 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
if (PSContext::instance()->encrypt_type() == kPWEncryptType) {
|
if (ps::PSContext::instance()->encrypt_type() == ps::kPWEncryptType) {
|
||||||
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
|
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||||
}
|
}
|
||||||
|
@ -164,5 +164,5 @@ bool ReconstructSecretsKernel::Reset() {
|
||||||
REG_ROUND_KERNEL(reconstructSecrets, ReconstructSecretsKernel)
|
REG_ROUND_KERNEL(reconstructSecrets, ReconstructSecretsKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -27,7 +27,7 @@
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class ReconstructSecretsKernel : public RoundKernel {
|
class ReconstructSecretsKernel : public RoundKernel {
|
||||||
|
@ -39,7 +39,7 @@ class ReconstructSecretsKernel : public RoundKernel {
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
bool Reset() override;
|
bool Reset() override;
|
||||||
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
|
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string name_unmask_;
|
std::string name_unmask_;
|
||||||
|
@ -49,6 +49,6 @@ class ReconstructSecretsKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) {
|
RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_(""), running_(true) {
|
||||||
|
@ -61,9 +61,9 @@ RoundKernel::~RoundKernel() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) { return; }
|
void RoundKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
|
||||||
|
|
||||||
void RoundKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) { return; }
|
void RoundKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) { return; }
|
||||||
|
|
||||||
void RoundKernel::StopTimer() const {
|
void RoundKernel::StopTimer() const {
|
||||||
if (stop_timer_cb_) {
|
if (stop_timer_cb_) {
|
||||||
|
@ -129,5 +129,5 @@ void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const v
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -35,7 +35,7 @@
|
||||||
#include "fl/server/distributed_metadata_store.h"
|
#include "fl/server/distributed_metadata_store.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
|
// RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round
|
||||||
|
@ -67,8 +67,8 @@ class RoundKernel : virtual public CPUKernel {
|
||||||
// The counter event handlers for DistributedCountService.
|
// The counter event handlers for DistributedCountService.
|
||||||
// The callbacks when first message and last message for this round kernel is received.
|
// The callbacks when first message and last message for this round kernel is received.
|
||||||
// These methods is called by class DistributedCountService and triggered by counting server.
|
// These methods is called by class DistributedCountService and triggered by counting server.
|
||||||
virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
virtual void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
virtual void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// Called when this round is finished. This round timer's Stop method will be called.
|
// Called when this round is finished. This round timer's Stop method will be called.
|
||||||
void StopTimer() const;
|
void StopTimer() const;
|
||||||
|
@ -123,6 +123,6 @@ class RoundKernel : virtual public CPUKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
RoundKernelFactory &RoundKernelFactory::GetInstance() {
|
RoundKernelFactory &RoundKernelFactory::GetInstance() {
|
||||||
|
@ -40,5 +40,5 @@ std::shared_ptr<RoundKernel> RoundKernelFactory::Create(const std::string &name)
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -25,7 +25,7 @@
|
||||||
#include "fl/server/kernel/round/round_kernel.h"
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>;
|
using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>;
|
||||||
|
@ -50,6 +50,7 @@ class RoundKernelRegister {
|
||||||
RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) {
|
RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) {
|
||||||
RoundKernelFactory::GetInstance().Register(name, std::move(creator));
|
RoundKernelFactory::GetInstance().Register(name, std::move(creator));
|
||||||
}
|
}
|
||||||
|
~RoundKernelRegister() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define REG_ROUND_KERNEL(NAME, CLASS) \
|
#define REG_ROUND_KERNEL(NAME, CLASS) \
|
||||||
|
@ -57,6 +58,6 @@ class RoundKernelRegister {
|
||||||
static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); });
|
static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); });
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void ShareSecretsKernel::InitKernel(size_t) {
|
void ShareSecretsKernel::InitKernel(size_t) {
|
||||||
|
@ -101,5 +101,5 @@ bool ShareSecretsKernel::Reset() {
|
||||||
REG_ROUND_KERNEL(shareSecrets, ShareSecretsKernel)
|
REG_ROUND_KERNEL(shareSecrets, ShareSecretsKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
@ -25,7 +25,7 @@
|
||||||
#include "fl/armour/cipher/cipher_shares.h"
|
#include "fl/armour/cipher/cipher_shares.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class ShareSecretsKernel : public RoundKernel {
|
class ShareSecretsKernel : public RoundKernel {
|
||||||
|
@ -44,7 +44,7 @@ class ShareSecretsKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void StartFLJobKernel::InitKernel(size_t) {
|
void StartFLJobKernel::InitKernel(size_t) {
|
||||||
|
@ -113,8 +113,8 @@ bool StartFLJobKernel::Reset() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) {
|
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
|
||||||
iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
|
iter_next_req_timestamp_ = CURRENT_TIME_MILLI.count() + iteration_time_window_;
|
||||||
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
|
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
|
||||||
// The first startFLJob request means a new iteration starts running.
|
// The first startFLJob request means a new iteration starts running.
|
||||||
Iteration::GetInstance().SetIterationRunning();
|
Iteration::GetInstance().SetIterationRunning();
|
||||||
|
@ -194,8 +194,8 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
std::map<std::string, AddressPtr> feature_maps) {
|
std::map<std::string, AddressPtr> feature_maps) {
|
||||||
auto fbs_reason = fbb->CreateString(reason);
|
auto fbs_reason = fbb->CreateString(reason);
|
||||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||||
auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode());
|
auto fbs_server_mode = fbb->CreateString(ps::PSContext::instance()->server_mode());
|
||||||
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
|
auto fbs_fl_name = fbb->CreateString(ps::PSContext::instance()->fl_name());
|
||||||
|
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
auto *param = armour::CipherInit::GetInstance().GetPublicParams();
|
auto *param = armour::CipherInit::GetInstance().GetPublicParams();
|
||||||
|
@ -206,7 +206,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
float dp_eps = param->dp_eps;
|
float dp_eps = param->dp_eps;
|
||||||
float dp_delta = param->dp_delta;
|
float dp_delta = param->dp_delta;
|
||||||
float dp_norm_clip = param->dp_norm_clip;
|
float dp_norm_clip = param->dp_norm_clip;
|
||||||
auto encrypt_type = fbb->CreateString(PSContext::instance()->encrypt_type());
|
auto encrypt_type = fbb->CreateString(ps::PSContext::instance()->encrypt_type());
|
||||||
|
|
||||||
auto cipher_public_params =
|
auto cipher_public_params =
|
||||||
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);
|
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);
|
||||||
|
@ -215,10 +215,10 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
|
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
|
||||||
fl_plan_builder.add_fl_name(fbs_fl_name);
|
fl_plan_builder.add_fl_name(fbs_fl_name);
|
||||||
fl_plan_builder.add_server_mode(fbs_server_mode);
|
fl_plan_builder.add_server_mode(fbs_server_mode);
|
||||||
fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num());
|
fl_plan_builder.add_iterations(ps::PSContext::instance()->fl_iteration_num());
|
||||||
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
|
fl_plan_builder.add_epochs(ps::PSContext::instance()->client_epoch_num());
|
||||||
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
|
fl_plan_builder.add_mini_batch(ps::PSContext::instance()->client_batch_size());
|
||||||
fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate());
|
fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate());
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
fl_plan_builder.add_cipher(cipher_public_params);
|
fl_plan_builder.add_cipher(cipher_public_params);
|
||||||
#endif
|
#endif
|
||||||
|
@ -250,5 +250,5 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
REG_ROUND_KERNEL(startFLJob, StartFLJobKernel)
|
REG_ROUND_KERNEL(startFLJob, StartFLJobKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -27,7 +27,7 @@
|
||||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class StartFLJobKernel : public RoundKernel {
|
class StartFLJobKernel : public RoundKernel {
|
||||||
|
@ -40,7 +40,7 @@ class StartFLJobKernel : public RoundKernel {
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
bool Reset() override;
|
bool Reset() override;
|
||||||
|
|
||||||
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
|
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Returns whether the startFLJob count of this iteration has reached the threshold.
|
// Returns whether the startFLJob count of this iteration has reached the threshold.
|
||||||
|
@ -74,6 +74,6 @@ class StartFLJobKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_START_FL_JOB_KERNEL_H_
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "fl/server/kernel/round/update_model_kernel.h"
|
#include "fl/server/kernel/round/update_model_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
void UpdateModelKernel::InitKernel(size_t threshold_count) {
|
void UpdateModelKernel::InitKernel(size_t threshold_count) {
|
||||||
|
@ -87,8 +87,8 @@ bool UpdateModelKernel::Reset() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &) {
|
void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (PSContext::instance()->resetter_round() == ResetterRound::kUpdateModel) {
|
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kUpdateModel) {
|
||||||
while (!executor_->IsAllWeightAggregationDone()) {
|
while (!executor_->IsAllWeightAggregationDone()) {
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||||
}
|
}
|
||||||
|
@ -96,7 +96,7 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
|
||||||
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
|
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
|
||||||
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
|
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
|
||||||
<< total_data_size;
|
<< total_data_size;
|
||||||
if (PSContext::instance()->encrypt_type() != kPWEncryptType) {
|
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
|
||||||
FinishIteration();
|
FinishIteration();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -226,5 +226,5 @@ void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fb
|
||||||
REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
|
REG_ROUND_KERNEL(updateModel, UpdateModelKernel)
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -27,7 +27,7 @@
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
|
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
|
||||||
|
@ -44,7 +44,7 @@ class UpdateModelKernel : public RoundKernel {
|
||||||
bool Reset() override;
|
bool Reset() override;
|
||||||
|
|
||||||
// In some cases, the last updateModel message means this server iteration is finished.
|
// In some cases, the last updateModel message means this server iteration is finished.
|
||||||
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
|
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
|
bool ReachThresholdForUpdateModel(const std::shared_ptr<FBBuilder> &fbb);
|
||||||
|
@ -62,6 +62,6 @@ class UpdateModelKernel : public RoundKernel {
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_UPDATE_MODEL_KERNEL_H_
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
#include "fl/server/local_meta_store.h"
|
#include "fl/server/local_meta_store.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
void LocalMetaStore::remove_value(const std::string &name) {
|
void LocalMetaStore::remove_value(const std::string &name) {
|
||||||
std::unique_lock<std::mutex> lock(mtx_);
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
|
@ -41,5 +41,5 @@ const size_t LocalMetaStore::curr_iter_num() {
|
||||||
return curr_iter_num_;
|
return curr_iter_num_;
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_
|
||||||
|
|
||||||
#include <any>
|
#include <any>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
@ -24,7 +24,7 @@
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// LocalMetaStore class is used for metadata storage of this server process.
|
// LocalMetaStore class is used for metadata storage of this server process.
|
||||||
// For example, the current iteration number, time windows for round kernels, etc.
|
// For example, the current iteration number, time windows for round kernels, etc.
|
||||||
|
@ -71,7 +71,7 @@ class LocalMetaStore {
|
||||||
const size_t curr_iter_num();
|
const size_t curr_iter_num();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LocalMetaStore() = default;
|
LocalMetaStore() : key_to_meta_({}), curr_iter_num_(0) {}
|
||||||
~LocalMetaStore() = default;
|
~LocalMetaStore() = default;
|
||||||
LocalMetaStore(const LocalMetaStore &) = delete;
|
LocalMetaStore(const LocalMetaStore &) = delete;
|
||||||
LocalMetaStore &operator=(const LocalMetaStore &) = delete;
|
LocalMetaStore &operator=(const LocalMetaStore &) = delete;
|
||||||
|
@ -83,6 +83,6 @@ class LocalMetaStore {
|
||||||
size_t curr_iter_num_;
|
size_t curr_iter_num_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_LOCAL_META_STORE_H_
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) {
|
void MemoryRegister::RegisterAddressPtr(const std::string &name, const AddressPtr &address) {
|
||||||
addresses_.try_emplace(name, address);
|
addresses_.try_emplace(name, address);
|
||||||
|
@ -32,5 +32,5 @@ void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64
|
||||||
|
|
||||||
void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); }
|
void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); }
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// Memory allocated in server is normally trainable parameters, hyperparameters, gradients, etc.
|
// Memory allocated in server is normally trainable parameters, hyperparameters, gradients, etc.
|
||||||
// MemoryRegister registers the Memory with key-value format where key refers to address's name("grad", "weights",
|
// MemoryRegister registers the Memory with key-value format where key refers to address's name("grad", "weights",
|
||||||
|
@ -88,6 +88,6 @@ class MemoryRegister {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_MEMORY_REGISTER_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_MEMORY_REGISTER_H_
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
void ModelStore::Initialize(uint32_t max_count) {
|
void ModelStore::Initialize(uint32_t max_count) {
|
||||||
if (!Executor::GetInstance().initialized()) {
|
if (!Executor::GetInstance().initialized()) {
|
||||||
|
@ -155,5 +155,5 @@ size_t ModelStore::ComputeModelSize() {
|
||||||
return model_size;
|
return model_size;
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -25,7 +25,7 @@
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// The initial iteration number is 0 in server.
|
// The initial iteration number is 0 in server.
|
||||||
constexpr size_t kInitIterationNum = 0;
|
constexpr size_t kInitIterationNum = 0;
|
||||||
|
@ -84,6 +84,6 @@ class ModelStore {
|
||||||
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
|
std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_MODEL_STORE_H_
|
||||||
|
|
|
@ -23,7 +23,7 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
|
bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
@ -199,8 +199,8 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
|
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
|
||||||
if (PSContext::instance()->server_mode() == kServerModeFL ||
|
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
|
||||||
PSContext::instance()->server_mode() == kServerModeHybrid) {
|
ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
|
||||||
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
|
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -321,13 +321,13 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
|
||||||
|
|
||||||
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
|
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
|
||||||
std::vector<std::string> aggregation_algorithm = {};
|
std::vector<std::string> aggregation_algorithm = {};
|
||||||
if (PSContext::instance()->server_mode() == kServerModeFL ||
|
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
|
||||||
PSContext::instance()->server_mode() == kServerModeHybrid) {
|
ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
|
||||||
aggregation_algorithm.push_back("FedAvg");
|
aggregation_algorithm.push_back("FedAvg");
|
||||||
} else if (PSContext::instance()->server_mode() == kServerModePS) {
|
} else if (ps::PSContext::instance()->server_mode() == ps::kServerModePS) {
|
||||||
aggregation_algorithm.push_back("DenseGradAccum");
|
aggregation_algorithm.push_back("DenseGradAccum");
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Server doesn't support mode " << PSContext::instance()->server_mode();
|
MS_LOG(ERROR) << "Server doesn't support mode " << ps::PSContext::instance()->server_mode();
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
|
MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
|
||||||
|
@ -344,5 +344,5 @@ template bool ParameterAggregator::AssignMemory(std::shared_ptr<kernel::Aggregat
|
||||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info,
|
||||||
std::shared_ptr<MemoryRegister> memory_register);
|
std::shared_ptr<MemoryRegister> memory_register);
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -28,7 +28,7 @@
|
||||||
#include "fl/server/kernel/optimizer_kernel_factory.h"
|
#include "fl/server/kernel/optimizer_kernel_factory.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server
|
// Encapsulate the parameters for a kernel into a struct to make it convenient for ParameterAggregator to launch server
|
||||||
// kernels.
|
// kernels.
|
||||||
|
@ -137,6 +137,6 @@ class ParameterAggregator {
|
||||||
std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
|
std::vector<std::pair<std::shared_ptr<kernel::OptimizerKernel>, KernelParams>> optimizer_kernel_parameters_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_PARAMETER_AGGREGATOR_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_PARAMETER_AGGREGATOR_H_
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "fl/server/iteration.h"
|
#include "fl/server/iteration.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
class Server;
|
class Server;
|
||||||
class Iteration;
|
class Iteration;
|
||||||
|
@ -34,14 +34,14 @@ Round::Round(const std::string &name, bool check_timeout, size_t time_window, bo
|
||||||
threshold_count_(threshold_count),
|
threshold_count_(threshold_count),
|
||||||
server_num_as_threshold_(server_num_as_threshold) {}
|
server_num_as_threshold_(server_num_as_threshold) {}
|
||||||
|
|
||||||
void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
|
void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
|
||||||
FinishIterCb finish_iteration_cb) {
|
FinishIterCb finish_iteration_cb) {
|
||||||
MS_EXCEPTION_IF_NULL(communicator);
|
MS_EXCEPTION_IF_NULL(communicator);
|
||||||
communicator_ = communicator;
|
communicator_ = communicator;
|
||||||
|
|
||||||
// Register callback for round kernel.
|
// Register callback for round kernel.
|
||||||
communicator_->RegisterMsgCallBack(
|
communicator_->RegisterMsgCallBack(
|
||||||
name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); });
|
name_, [&](std::shared_ptr<ps::core::MessageHandler> message) { LaunchRoundKernel(message); });
|
||||||
|
|
||||||
// Callback when the iteration is finished.
|
// Callback when the iteration is finished.
|
||||||
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
|
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
|
||||||
|
@ -106,7 +106,7 @@ void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel)
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) {
|
void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
if (message == nullptr) {
|
if (message == nullptr) {
|
||||||
MS_LOG(ERROR) << "Message is nullptr.";
|
MS_LOG(ERROR) << "Message is nullptr.";
|
||||||
return;
|
return;
|
||||||
|
@ -152,7 +152,7 @@ bool Round::check_timeout() const { return check_timeout_; }
|
||||||
|
|
||||||
size_t Round::time_window() const { return time_window_; }
|
size_t Round::time_window() const { return time_window_; }
|
||||||
|
|
||||||
void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
|
MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
|
||||||
// The timer starts only after the first count event is triggered by DistributedCountService.
|
// The timer starts only after the first count event is triggered by DistributedCountService.
|
||||||
if (check_timeout_) {
|
if (check_timeout_) {
|
||||||
|
@ -164,7 +164,7 @@ void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &messa
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
|
MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
|
||||||
// Same as the first count event, the timer must be stopped by DistributedCountService.
|
// Same as the first count event, the timer must be stopped by DistributedCountService.
|
||||||
if (check_timeout_) {
|
if (check_timeout_) {
|
||||||
|
@ -176,5 +176,5 @@ void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &messag
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -26,7 +26,7 @@
|
||||||
#include "fl/server/kernel/round/round_kernel.h"
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// Round helps server to handle network round messages and launch round kernels. One iteration in server consists of
|
// Round helps server to handle network round messages and launch round kernels. One iteration in server consists of
|
||||||
// multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting
|
// multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting
|
||||||
|
@ -37,7 +37,7 @@ class Round {
|
||||||
bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
|
bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
|
||||||
~Round() = default;
|
~Round() = default;
|
||||||
|
|
||||||
void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
|
void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, TimeOutCb timeout_cb,
|
||||||
FinishIterCb finish_iteration_cb);
|
FinishIterCb finish_iteration_cb);
|
||||||
|
|
||||||
// Reinitialize count service and round kernel of this round after scaling operations are done.
|
// Reinitialize count service and round kernel of this round after scaling operations are done.
|
||||||
|
@ -48,7 +48,7 @@ class Round {
|
||||||
|
|
||||||
// This method is the callback which will be set to the communicator and called after the corresponding round message
|
// This method is the callback which will be set to the communicator and called after the corresponding round message
|
||||||
// is sent to the server.
|
// is sent to the server.
|
||||||
void LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message);
|
void LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// Round needs to be reset after each iteration is finished or its timer expires.
|
// Round needs to be reset after each iteration is finished or its timer expires.
|
||||||
void Reset();
|
void Reset();
|
||||||
|
@ -60,8 +60,8 @@ class Round {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// The callbacks which will be set to DistributedCounterService.
|
// The callbacks which will be set to DistributedCounterService.
|
||||||
void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message);
|
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
std::string name_;
|
std::string name_;
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ class Round {
|
||||||
// Whether this round uses the server number as its threshold count.
|
// Whether this round uses the server number as its threshold count.
|
||||||
bool server_num_as_threshold_;
|
bool server_num_as_threshold_;
|
||||||
|
|
||||||
std::shared_ptr<core::CommunicatorBase> communicator_;
|
std::shared_ptr<ps::core::CommunicatorBase> communicator_;
|
||||||
|
|
||||||
// The round kernel for this Round.
|
// The round kernel for this Round.
|
||||||
std::shared_ptr<kernel::RoundKernel> kernel_;
|
std::shared_ptr<kernel::RoundKernel> kernel_;
|
||||||
|
@ -97,6 +97,6 @@ class Round {
|
||||||
FinalizeCb finalize_cb_;
|
FinalizeCb finalize_cb_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
|
||||||
|
|
|
@ -30,17 +30,8 @@
|
||||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
static std::vector<std::shared_ptr<core::CommunicatorBase>> global_worker_server_comms = {};
|
|
||||||
// This function is for the exit of server process when an interrupt signal is captured.
|
|
||||||
void SignalHandler(int signal) {
|
|
||||||
MS_LOG(INFO) << "Interrupt signal captured: " << signal;
|
|
||||||
std::for_each(global_worker_server_comms.begin(), global_worker_server_comms.end(),
|
|
||||||
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
|
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
|
||||||
const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
|
const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
@ -76,7 +67,6 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
|
||||||
// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
|
// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
|
||||||
// InitCipher---->InitExecutor
|
// InitCipher---->InitExecutor
|
||||||
void Server::Run() {
|
void Server::Run() {
|
||||||
signal(SIGINT, SignalHandler);
|
|
||||||
std::unique_lock<std::mutex> lock(scaling_mtx_);
|
std::unique_lock<std::mutex> lock(scaling_mtx_);
|
||||||
InitServerContext();
|
InitServerContext();
|
||||||
InitCluster();
|
InitCluster();
|
||||||
|
@ -84,8 +74,8 @@ void Server::Run() {
|
||||||
RegisterCommCallbacks();
|
RegisterCommCallbacks();
|
||||||
StartCommunicator();
|
StartCommunicator();
|
||||||
InitExecutor();
|
InitExecutor();
|
||||||
std::string encrypt_type = PSContext::instance()->encrypt_type();
|
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||||
if (encrypt_type != kNotEncryptType) {
|
if (encrypt_type != ps::kNotEncryptType) {
|
||||||
InitCipher();
|
InitCipher();
|
||||||
MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
|
MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
|
||||||
}
|
}
|
||||||
|
@ -96,7 +86,7 @@ void Server::Run() {
|
||||||
|
|
||||||
// Wait communicators to stop so the main thread is blocked.
|
// Wait communicators to stop so the main thread is blocked.
|
||||||
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||||
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Join(); });
|
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Join(); });
|
||||||
communicator_with_server_->Join();
|
communicator_with_server_->Join();
|
||||||
MsException::Instance().CheckException();
|
MsException::Instance().CheckException();
|
||||||
return;
|
return;
|
||||||
|
@ -115,18 +105,18 @@ void Server::CancelSafeMode() {
|
||||||
bool Server::IsSafeMode() { return safemode_.load(); }
|
bool Server::IsSafeMode() { return safemode_.load(); }
|
||||||
|
|
||||||
void Server::InitServerContext() {
|
void Server::InitServerContext() {
|
||||||
PSContext::instance()->GenerateResetterRound();
|
ps::PSContext::instance()->GenerateResetterRound();
|
||||||
scheduler_ip_ = PSContext::instance()->scheduler_host();
|
scheduler_ip_ = ps::PSContext::instance()->scheduler_host();
|
||||||
scheduler_port_ = PSContext::instance()->scheduler_port();
|
scheduler_port_ = ps::PSContext::instance()->scheduler_port();
|
||||||
worker_num_ = PSContext::instance()->initial_worker_num();
|
worker_num_ = ps::PSContext::instance()->initial_worker_num();
|
||||||
server_num_ = PSContext::instance()->initial_server_num();
|
server_num_ = ps::PSContext::instance()->initial_server_num();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Server::InitCluster() {
|
void Server::InitCluster() {
|
||||||
server_node_ = std::make_shared<core::ServerNode>();
|
server_node_ = std::make_shared<ps::core::ServerNode>();
|
||||||
MS_EXCEPTION_IF_NULL(server_node_);
|
MS_EXCEPTION_IF_NULL(server_node_);
|
||||||
task_executor_ = std::make_shared<core::TaskExecutor>(32);
|
task_executor_ = std::make_shared<ps::core::TaskExecutor>(32);
|
||||||
MS_EXCEPTION_IF_NULL(task_executor_);
|
MS_EXCEPTION_IF_NULL(task_executor_);
|
||||||
if (!InitCommunicatorWithServer()) {
|
if (!InitCommunicatorWithServer()) {
|
||||||
MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
|
MS_LOG(EXCEPTION) << "Initializing cross-server communicator failed.";
|
||||||
|
@ -136,7 +126,6 @@ void Server::InitCluster() {
|
||||||
MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
|
MS_LOG(EXCEPTION) << "Initializing worker-server communicator failed.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
global_worker_server_comms = communicators_with_worker_;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,8 +176,8 @@ void Server::InitIteration() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
std::string encrypt_type = PSContext::instance()->encrypt_type();
|
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||||
if (encrypt_type == kPWEncryptType) {
|
if (encrypt_type == ps::kPWEncryptType) {
|
||||||
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
|
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
|
||||||
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
|
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
|
||||||
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
|
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
|
||||||
|
@ -245,10 +234,10 @@ void Server::InitCipher() {
|
||||||
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
|
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
|
||||||
int cipher_g = 1;
|
int cipher_g = 1;
|
||||||
unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
|
unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
|
||||||
float dp_eps = PSContext::instance()->dp_eps();
|
float dp_eps = ps::PSContext::instance()->dp_eps();
|
||||||
float dp_delta = PSContext::instance()->dp_delta();
|
float dp_delta = ps::PSContext::instance()->dp_delta();
|
||||||
float dp_norm_clip = PSContext::instance()->dp_norm_clip();
|
float dp_norm_clip = ps::PSContext::instance()->dp_norm_clip();
|
||||||
std::string encrypt_type = PSContext::instance()->encrypt_type();
|
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||||
|
|
||||||
mpz_t prim;
|
mpz_t prim;
|
||||||
mpz_init(prim);
|
mpz_init(prim);
|
||||||
|
@ -276,7 +265,7 @@ void Server::RegisterCommCallbacks() {
|
||||||
// The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
|
// The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
|
||||||
// rounds' callbacks.
|
// rounds' callbacks.
|
||||||
|
|
||||||
auto tcp_comm = std::dynamic_pointer_cast<core::TcpCommunicator>(communicator_with_server_);
|
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
||||||
MS_EXCEPTION_IF_NULL(tcp_comm);
|
MS_EXCEPTION_IF_NULL(tcp_comm);
|
||||||
|
|
||||||
// Set message callbacks for server-to-server communication.
|
// Set message callbacks for server-to-server communication.
|
||||||
|
@ -304,23 +293,23 @@ void Server::RegisterCommCallbacks() {
|
||||||
std::bind(&Server::ProcessAfterScalingIn, this));
|
std::bind(&Server::ProcessAfterScalingIn, this));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Server::RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator) {
|
void Server::RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
|
||||||
MS_EXCEPTION_IF_NULL(communicator);
|
MS_EXCEPTION_IF_NULL(communicator);
|
||||||
communicator->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
|
communicator->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [&]() {
|
||||||
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
|
MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
|
||||||
safemode_ = true;
|
safemode_ = true;
|
||||||
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||||
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
|
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
|
||||||
communicator_with_server_->Stop();
|
communicator_with_server_->Stop();
|
||||||
});
|
});
|
||||||
|
|
||||||
communicator->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [&]() {
|
communicator->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [&]() {
|
||||||
MS_LOG(ERROR)
|
MS_LOG(ERROR)
|
||||||
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
|
<< "Event NODE_TIMEOUT is captured. This is because some server nodes are finalized or crashed after the "
|
||||||
"network building phase.";
|
"network building phase.";
|
||||||
safemode_ = true;
|
safemode_ = true;
|
||||||
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||||
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
|
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
|
||||||
communicator_with_server_->Stop();
|
communicator_with_server_->Stop();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -377,7 +366,7 @@ void Server::StartCommunicator() {
|
||||||
|
|
||||||
MS_LOG(INFO) << "Start communicator with worker.";
|
MS_LOG(INFO) << "Start communicator with worker.";
|
||||||
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||||
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Start(); });
|
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Start(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
void Server::ProcessBeforeScalingOut() {
|
void Server::ProcessBeforeScalingOut() {
|
||||||
|
@ -424,7 +413,7 @@ void Server::ProcessAfterScalingIn() {
|
||||||
if (server_node_->rank_id() == UINT32_MAX) {
|
if (server_node_->rank_id() == UINT32_MAX) {
|
||||||
MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
|
MS_LOG(WARNING) << "This server the one to be scaled in. Server exiting.";
|
||||||
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
std::for_each(communicators_with_worker_.begin(), communicators_with_worker_.end(),
|
||||||
[](const std::shared_ptr<core::CommunicatorBase> &communicator) { communicator->Stop(); });
|
[](const std::shared_ptr<ps::core::CommunicatorBase> &communicator) { communicator->Stop(); });
|
||||||
communicator_with_server_->Stop();
|
communicator_with_server_->Stop();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -449,5 +438,5 @@ void Server::ProcessAfterScalingIn() {
|
||||||
safemode_ = false;
|
safemode_ = false;
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
|
||||||
#define MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
|
#define MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -31,7 +31,7 @@
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
|
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
|
||||||
class Server {
|
class Server {
|
||||||
|
@ -90,7 +90,7 @@ class Server {
|
||||||
void RegisterCommCallbacks();
|
void RegisterCommCallbacks();
|
||||||
|
|
||||||
// Register cluster exception callbacks. This method is called in RegisterCommCallbacks.
|
// Register cluster exception callbacks. This method is called in RegisterCommCallbacks.
|
||||||
void RegisterExceptionEventCallback(const std::shared_ptr<core::TcpCommunicator> &communicator);
|
void RegisterExceptionEventCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
|
||||||
|
|
||||||
// Initialize executor according to the server mode.
|
// Initialize executor according to the server mode.
|
||||||
void InitExecutor();
|
void InitExecutor();
|
||||||
|
@ -113,11 +113,11 @@ class Server {
|
||||||
void ProcessAfterScalingIn();
|
void ProcessAfterScalingIn();
|
||||||
|
|
||||||
// The server node is initialized in Server.
|
// The server node is initialized in Server.
|
||||||
std::shared_ptr<core::ServerNode> server_node_;
|
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||||
|
|
||||||
// The task executor of the communicators. This helps server to handle network message concurrently. The tasks
|
// The task executor of the communicators. This helps server to handle network message concurrently. The tasks
|
||||||
// submitted to this task executor is asynchronous.
|
// submitted to this task executor is asynchronous.
|
||||||
std::shared_ptr<core::TaskExecutor> task_executor_;
|
std::shared_ptr<ps::core::TaskExecutor> task_executor_;
|
||||||
|
|
||||||
// Which protocol should communicators use.
|
// Which protocol should communicators use.
|
||||||
bool use_tcp_;
|
bool use_tcp_;
|
||||||
|
@ -136,12 +136,12 @@ class Server {
|
||||||
|
|
||||||
// Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective
|
// Server need a tcp communicator to communicate with other servers for counting, metadata storing, collective
|
||||||
// operations, etc.
|
// operations, etc.
|
||||||
std::shared_ptr<core::CommunicatorBase> communicator_with_server_;
|
std::shared_ptr<ps::core::CommunicatorBase> communicator_with_server_;
|
||||||
|
|
||||||
// The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP.
|
// The communication with workers(including mobile devices), has multiple protocol types: HTTP and TCP.
|
||||||
// In some cases, both types should be supported in one distributed training job. So here we may have multiple
|
// In some cases, both types should be supported in one distributed training job. So here we may have multiple
|
||||||
// communicators.
|
// communicators.
|
||||||
std::vector<std::shared_ptr<core::CommunicatorBase>> communicators_with_worker_;
|
std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
|
||||||
|
|
||||||
// Mutex for scaling operations. We must wait server's initialization done before handle scaling events.
|
// Mutex for scaling operations. We must wait server's initialization done before handle scaling events.
|
||||||
std::mutex scaling_mtx_;
|
std::mutex scaling_mtx_;
|
||||||
|
@ -176,6 +176,6 @@ class Server {
|
||||||
float percent_for_get_model_;
|
float percent_for_get_model_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_SERVER_SERVER_H_
|
#endif // MINDSPORE_CCSRC_FL_SERVER_SERVER_H_
|
||||||
|
|
|
@ -22,27 +22,27 @@
|
||||||
#include "utils/ms_exception.h"
|
#include "utils/ms_exception.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
namespace worker {
|
namespace worker {
|
||||||
void FLWorker::Run() {
|
void FLWorker::Run() {
|
||||||
worker_num_ = PSContext::instance()->worker_num();
|
worker_num_ = ps::PSContext::instance()->worker_num();
|
||||||
server_num_ = PSContext::instance()->server_num();
|
server_num_ = ps::PSContext::instance()->server_num();
|
||||||
scheduler_ip_ = PSContext::instance()->scheduler_ip();
|
scheduler_ip_ = ps::PSContext::instance()->scheduler_ip();
|
||||||
scheduler_port_ = PSContext::instance()->scheduler_port();
|
scheduler_port_ = ps::PSContext::instance()->scheduler_port();
|
||||||
worker_step_num_per_iteration_ = PSContext::instance()->worker_step_num_per_iteration();
|
worker_step_num_per_iteration_ = ps::PSContext::instance()->worker_step_num_per_iteration();
|
||||||
PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_;
|
ps::PSContext::instance()->cluster_config().scheduler_host = scheduler_ip_;
|
||||||
PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
|
ps::PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
|
||||||
PSContext::instance()->cluster_config().initial_worker_num = worker_num_;
|
ps::PSContext::instance()->cluster_config().initial_worker_num = worker_num_;
|
||||||
PSContext::instance()->cluster_config().initial_server_num = server_num_;
|
ps::PSContext::instance()->cluster_config().initial_server_num = server_num_;
|
||||||
MS_LOG(INFO) << "Initialize cluster config for worker. Worker number:" << worker_num_
|
MS_LOG(INFO) << "Initialize cluster config for worker. Worker number:" << worker_num_
|
||||||
<< ", Server number:" << server_num_ << ", Scheduler ip:" << scheduler_ip_
|
<< ", Server number:" << server_num_ << ", Scheduler ip:" << scheduler_ip_
|
||||||
<< ", Scheduler port:" << scheduler_port_
|
<< ", Scheduler port:" << scheduler_port_
|
||||||
<< ", Worker training step per iteration:" << worker_step_num_per_iteration_;
|
<< ", Worker training step per iteration:" << worker_step_num_per_iteration_;
|
||||||
|
|
||||||
worker_node_ = std::make_shared<core::WorkerNode>();
|
worker_node_ = std::make_shared<ps::core::WorkerNode>();
|
||||||
MS_EXCEPTION_IF_NULL(worker_node_);
|
MS_EXCEPTION_IF_NULL(worker_node_);
|
||||||
|
|
||||||
worker_node_->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
|
worker_node_->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
|
||||||
Finalize();
|
Finalize();
|
||||||
try {
|
try {
|
||||||
MS_LOG(EXCEPTION)
|
MS_LOG(EXCEPTION)
|
||||||
|
@ -51,7 +51,7 @@ void FLWorker::Run() {
|
||||||
MsException::Instance().SetException();
|
MsException::Instance().SetException();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
worker_node_->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() {
|
worker_node_->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
|
||||||
Finalize();
|
Finalize();
|
||||||
try {
|
try {
|
||||||
MS_LOG(EXCEPTION)
|
MS_LOG(EXCEPTION)
|
||||||
|
@ -74,7 +74,7 @@ void FLWorker::Finalize() {
|
||||||
worker_node_->Stop();
|
worker_node_->Stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, core::TcpUserCommand command,
|
bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
|
||||||
std::shared_ptr<std::vector<unsigned char>> *output) {
|
std::shared_ptr<std::vector<unsigned char>> *output) {
|
||||||
// If the worker is in safemode, do not communicate with server.
|
// If the worker is in safemode, do not communicate with server.
|
||||||
while (safemode_.load()) {
|
while (safemode_.load()) {
|
||||||
|
@ -97,7 +97,8 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
|
||||||
|
|
||||||
if (output != nullptr) {
|
if (output != nullptr) {
|
||||||
while (true) {
|
while (true) {
|
||||||
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command), output)) {
|
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command),
|
||||||
|
output)) {
|
||||||
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
|
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -106,7 +107,7 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == kClusterSafeMode) {
|
if (std::string(reinterpret_cast<char *>((*output)->data()), (*output)->size()) == ps::kClusterSafeMode) {
|
||||||
MS_LOG(INFO) << "The server " << server_rank << " is in safemode.";
|
MS_LOG(INFO) << "The server " << server_rank << " is in safemode.";
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode));
|
std::this_thread::sleep_for(std::chrono::milliseconds(kWorkerRetryDurationForSafeMode));
|
||||||
} else {
|
} else {
|
||||||
|
@ -114,7 +115,7 @@ bool FLWorker::SendToServer(uint32_t server_rank, const void *data, size_t size,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!worker_node_->Send(core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
|
if (!worker_node_->Send(ps::core::NodeRole::SERVER, server_rank, message, size, static_cast<int>(command))) {
|
||||||
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
|
MS_LOG(ERROR) << "Sending message to server " << server_rank << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -155,9 +156,9 @@ void FLWorker::InitializeFollowerScaler() {
|
||||||
std::bind(&FLWorker::ProcessAfterScalingOut, this));
|
std::bind(&FLWorker::ProcessAfterScalingOut, this));
|
||||||
worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
|
worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
|
||||||
std::bind(&FLWorker::ProcessAfterScalingIn, this));
|
std::bind(&FLWorker::ProcessAfterScalingIn, this));
|
||||||
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationRunning),
|
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
|
||||||
std::bind(&FLWorker::HandleIterationRunningEvent, this));
|
std::bind(&FLWorker::HandleIterationRunningEvent, this));
|
||||||
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(CustomEvent::kIterationCompleted),
|
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
|
||||||
std::bind(&FLWorker::HandleIterationCompletedEvent, this));
|
std::bind(&FLWorker::HandleIterationCompletedEvent, this));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,5 +223,5 @@ void FLWorker::ProcessAfterScalingIn() {
|
||||||
safemode_ = false;
|
safemode_ = false;
|
||||||
}
|
}
|
||||||
} // namespace worker
|
} // namespace worker
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_
|
#ifndef MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
|
||||||
#define MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_
|
#define MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
@ -28,7 +28,7 @@
|
||||||
#include "ps/core/communicator/tcp_communicator.h"
|
#include "ps/core/communicator/tcp_communicator.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace fl {
|
||||||
using FBBuilder = flatbuffers::FlatBufferBuilder;
|
using FBBuilder = flatbuffers::FlatBufferBuilder;
|
||||||
|
|
||||||
// The step number for worker to judge whether to communicate with server.
|
// The step number for worker to judge whether to communicate with server.
|
||||||
|
@ -59,7 +59,7 @@ class FLWorker {
|
||||||
}
|
}
|
||||||
void Run();
|
void Run();
|
||||||
void Finalize();
|
void Finalize();
|
||||||
bool SendToServer(uint32_t server_rank, const void *data, size_t size, core::TcpUserCommand command,
|
bool SendToServer(uint32_t server_rank, const void *data, size_t size, ps::core::TcpUserCommand command,
|
||||||
std::shared_ptr<std::vector<unsigned char>> *output = nullptr);
|
std::shared_ptr<std::vector<unsigned char>> *output = nullptr);
|
||||||
|
|
||||||
uint32_t server_num() const;
|
uint32_t server_num() const;
|
||||||
|
@ -104,7 +104,7 @@ class FLWorker {
|
||||||
uint32_t worker_num_;
|
uint32_t worker_num_;
|
||||||
std::string scheduler_ip_;
|
std::string scheduler_ip_;
|
||||||
uint16_t scheduler_port_;
|
uint16_t scheduler_port_;
|
||||||
std::shared_ptr<core::WorkerNode> worker_node_;
|
std::shared_ptr<ps::core::WorkerNode> worker_node_;
|
||||||
|
|
||||||
// The worker standalone training step number before communicating with server. This used in hybrid training mode.
|
// The worker standalone training step number before communicating with server. This used in hybrid training mode.
|
||||||
uint64_t worker_step_num_per_iteration_;
|
uint64_t worker_step_num_per_iteration_;
|
||||||
|
@ -121,6 +121,6 @@ class FLWorker {
|
||||||
std::atomic_bool safemode_;
|
std::atomic_bool safemode_;
|
||||||
};
|
};
|
||||||
} // namespace worker
|
} // namespace worker
|
||||||
} // namespace ps
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_PS_WORKER_FL_WORKER_H_
|
#endif // MINDSPORE_CCSRC_FL_WORKER_FL_WORKER_H_
|
||||||
|
|
|
@ -639,7 +639,7 @@ bool StartPSWorkerAction(const ResourcePtr &res) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
bool StartFLWorkerAction(const ResourcePtr &) {
|
bool StartFLWorkerAction(const ResourcePtr &) {
|
||||||
ps::worker::FLWorker::GetInstance().Run();
|
fl::worker::FLWorker::GetInstance().Run();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -665,7 +665,7 @@ bool StartServerAction(const ResourcePtr &res) {
|
||||||
uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
|
uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
|
||||||
uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
|
uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();
|
||||||
|
|
||||||
std::vector<ps::server::RoundConfig> rounds_config = {
|
std::vector<fl::server::RoundConfig> rounds_config = {
|
||||||
{"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
|
{"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
|
||||||
{"updateModel", true, update_model_time_window, true, update_model_threshold},
|
{"updateModel", true, update_model_time_window, true, update_model_threshold},
|
||||||
{"getModel"},
|
{"getModel"},
|
||||||
|
@ -676,22 +676,22 @@ bool StartServerAction(const ResourcePtr &res) {
|
||||||
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
|
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
|
||||||
size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
|
size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
|
||||||
|
|
||||||
ps::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
|
fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
|
||||||
|
|
||||||
size_t executor_threshold = 0;
|
size_t executor_threshold = 0;
|
||||||
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
||||||
executor_threshold = update_model_threshold;
|
executor_threshold = update_model_threshold;
|
||||||
ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
|
fl::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
|
||||||
executor_threshold);
|
executor_threshold);
|
||||||
} else if (server_mode_ == ps::kServerModePS) {
|
} else if (server_mode_ == ps::kServerModePS) {
|
||||||
executor_threshold = worker_num;
|
executor_threshold = worker_num;
|
||||||
ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
|
fl::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
|
||||||
executor_threshold);
|
executor_threshold);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
|
MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ps::server::Server::GetInstance().Run();
|
fl::server::Server::GetInstance().Run();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1293,7 +1293,7 @@ void ClearResAtexit() {
|
||||||
MS_LOG(INFO) << "Start finalizing worker.";
|
MS_LOG(INFO) << "Start finalizing worker.";
|
||||||
const std::string &server_mode = ps::PSContext::instance()->server_mode();
|
const std::string &server_mode = ps::PSContext::instance()->server_mode();
|
||||||
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid)) {
|
if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid)) {
|
||||||
ps::worker::FLWorker::GetInstance().Finalize();
|
fl::worker::FLWorker::GetInstance().Finalize();
|
||||||
} else {
|
} else {
|
||||||
ps::Worker::GetInstance().Finalize();
|
ps::Worker::GetInstance().Finalize();
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
package mindspore.ps;
|
package mindspore.fl;
|
||||||
|
|
||||||
message CollectiveData {
|
message CollectiveData {
|
||||||
bytes data = 1;
|
bytes data = 1;
|
||||||
|
|
|
@ -286,8 +286,9 @@ void PSContext::GenerateResetterRound() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
|
binary_server_context = ((unsigned int)is_parameter_server_mode << 0) |
|
||||||
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3);
|
((unsigned int)is_federated_learning_mode << 1) |
|
||||||
|
((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)secure_aggregation_ << 3);
|
||||||
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
|
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
|
||||||
resetter_round_ = ResetterRound::kNoNeedToReset;
|
resetter_round_ = ResetterRound::kNoNeedToReset;
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Reference in New Issue