diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc index aed4f569cac..bf058f53748 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc @@ -15,8 +15,10 @@ */ #include "fl/armour/cipher/cipher_init.h" -#include "fl/server/common.h" + #include "fl/armour/cipher/cipher_meta_storage.h" +#include "fl/server/common.h" +#include "fl/server/model_store.h" namespace mindspore { namespace armour { @@ -43,8 +45,7 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size publicparam_.t = param.t; secrets_minnums_ = param.t; client_num_need_ = cipher_initial_client_cnt; - featuremap_ = 1000; // todo: wait for other code - // merge.ps::server::DistributedMetadataStore::GetInstance().model_size() / sizeof(float); + featuremap_ = ps::server::ModelStore::GetInstance().model_size() / sizeof(float); share_clients_num_need_ = cipher_share_secrets_cnt; reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1; get_model_num_need_ = cipher_get_clientlist_cnt; diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc index e3474e14fc2..a0b2e9bf1e6 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc @@ -38,14 +38,14 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim std::string fl_id = get_exchange_keys_req->fl_id()->str(); if (find(clients.begin(), clients.end(), fl_id) == clients.end()) { - BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false); MS_LOG(INFO) << "The fl_id is not in clients."; + BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false); return false; } if (cur_clients_num < cipher_init_->client_num_need_) { - BuildGetKeys(get_exchange_keys_resp_builder,
schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false); MS_LOG(INFO) << "The server is not ready yet: cur_clients_num < client_num_need"; - MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << "cur_clients_num : " << cipher_init_->client_num_need_; + MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << ", client_num_need : " << cipher_init_->client_num_need_; + BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false); return false; } @@ -55,7 +55,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim } bool flag = - BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, cur_iterator, next_req_time, true); + BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true); return flag; } // namespace armour @@ -95,17 +95,17 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re MS_LOG(INFO) << "client_num_need_ " << cipher_init_->client_num_need_ << ". cur_clients_num " << cur_clients_num; std::string fl_id = exchange_keys_req->fl_id()->str(); if (cur_clients_num >= cipher_init_->client_num_need_) { // the client num is enough, return false. + MS_LOG(ERROR) << "The server has received enough requests and refuse this request."; BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, "The server has received enough requests and refuse this request.", next_req_time, cur_iterator); - MS_LOG(ERROR) << "The server has received enough requests and refuse this request."; return false; } if (record_public_keys.find(fl_id) != record_public_keys.end()) { // the client already exists, return false.
+ MS_LOG(INFO) << "The server has received the request, please do not request again."; BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, "The server has received the request, please do not request again.", next_req_time, cur_iterator); - MS_LOG(INFO) << "The server has received the request, please do not request again."; return false; } @@ -135,9 +135,9 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re bool retcode_client = cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxExChangeKeysClientList, fl_id); if (retcode_key && retcode_client) { + MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success"; BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, "Success, but the server is not ready yet.", next_req_time, cur_iterator); - MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success"; return true; } else { MS_LOG(ERROR) << "update key or client failed"; @@ -164,7 +164,6 @@ void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr exc bool CipherKeys::BuildGetKeys(std::shared_ptr fbb, const schema::ResponseCode retcode, const int iteration, const std::string &next_req_time, bool is_good) { - schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get())); bool flag = true; if (is_good) { // convert client keys to standard keys list. @@ -176,6 +175,14 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr fbb, const MS_LOG(INFO) << "NOT READY. 
keys num: " << record_public_keys.size() << "clients num: " << cipher_init_->client_num_need_; flag = false; + auto fbs_next_req_time = fbb->CreateString(next_req_time); + schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get())); + rsp_buider.add_retcode(retcode); + rsp_buider.add_iteration(iteration); + rsp_buider.add_next_req_time(fbs_next_req_time); + auto rsp_get_keys = rsp_buider.Finish(); + + fbb->Finish(rsp_get_keys); } else { for (auto iter = record_public_keys.begin(); iter != record_public_keys.end(); ++iter) { // read (fl_id, c_pk, s_pk) from the map: record_public_keys_ @@ -190,18 +197,26 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr fbb, const public_keys_list.push_back(cur_public_key); } auto remote_publickeys = fbb->CreateVector(public_keys_list); + auto fbs_next_req_time = fbb->CreateString(next_req_time); + schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get())); + rsp_buider.add_retcode(retcode); + rsp_buider.add_iteration(iteration); rsp_buider.add_remote_publickeys(remote_publickeys); + rsp_buider.add_next_req_time(fbs_next_req_time); + auto rsp_get_keys = rsp_buider.Finish(); + fbb->Finish(rsp_get_keys); MS_LOG(INFO) << "CipherMgr::GetKeys Success"; - flag = true; } - } + } else { + auto fbs_next_req_time = fbb->CreateString(next_req_time); + schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get())); + rsp_buider.add_retcode(retcode); + rsp_buider.add_iteration(iteration); + rsp_buider.add_next_req_time(fbs_next_req_time); + auto rsp_get_keys = rsp_buider.Finish(); - auto fbs_next_req_time = fbb->CreateString(next_req_time); - rsp_buider.add_retcode(retcode); - rsp_buider.add_iteration(iteration); - rsp_buider.add_next_req_time(fbs_next_req_time); - auto rsp_get_keys = rsp_buider.Finish(); - fbb->Finish(rsp_get_keys); + fbb->Finish(rsp_get_keys); + } return flag; } diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc index f1839419eeb..37ab1e554ac 100644 --- 
a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc @@ -82,16 +82,19 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve return true; } -bool CipherMetaStorage::GetPrimeFromServer(const char *list_name, unsigned char *prime) { - const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); - auto &prime_list_pb = prime_pb_out.prime_list(); - if (prime_list_pb.prime_size() > 0 && prime_list_pb.prime(0).size() >= PRIME_MAX_LEN) { - for (int i = 0; i < PRIME_MAX_LEN; i++) { - prime[i] = static_cast(prime_list_pb.prime(0)[i]); - } +bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) { + const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name); + ps::Prime prime_pb(prime_pb_out.prime()); + std::string str = *(prime_pb.mutable_prime()); + MS_LOG(INFO) << "get prime from metastorage :" << str; + + if (str.size() != PRIME_MAX_LEN) { + MS_LOG(ERROR) << "get prime size is :" << str.size(); + return false; + } else { + memcpy_s(prime, PRIME_MAX_LEN, str.data(), PRIME_MAX_LEN); return true; } - return false; } bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) { @@ -104,6 +107,7 @@ bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::s return retcode; } void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) { + MS_LOG(INFO) << "register prime: " << prime; ps::Prime prime_id_pb; prime_id_pb.set_prime(prime); ps::PBMetadata prime_pb; diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h index 489d3874c66..685ec85c92c 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h @@ -65,7 +65,7 @@ class 
CipherMetaStorage { // Register Prime. void RegisterPrime(const char *list_name, const std::string &prime); // Get tprime from shared server. - bool GetPrimeFromServer(const char *list_name, unsigned char *prime); + bool GetPrimeFromServer(const char *prime_name, unsigned char *prime); // Get client shares from shared server. void GetClientSharesFromServer(const char *list_name, std::map> *clients_shares_list); diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc index 13a623c8d9e..91571e306ae 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_reconstruct.cc @@ -77,7 +77,7 @@ bool CipherReconStruct::CombineMask( std::vector noise(cipher_init_->featuremap_, 0.0); if (GetSuvNoise(clients_share_list, record_public_keys, fl_id, &noise, secret, length) == false) retcode = false; - client_keys->at(fl_id) = noise; + client_keys->insert(std::pair>(fl_id, noise)); MS_LOG(INFO) << " fl_id : " << fl_id; MS_LOG(INFO) << "end get complete s_uv."; } else { @@ -87,7 +87,7 @@ bool CipherReconStruct::CombineMask( for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) { noise[index_noise] *= -1; } - client_keys->at(fl_id) = noise; + client_keys->insert(std::pair>(fl_id, noise)); MS_LOG(INFO) << " fl_id : " << fl_id; } } diff --git a/mindspore/ccsrc/fl/server/common.h b/mindspore/ccsrc/fl/server/common.h index e95edabe554..83b28bf9f82 100644 --- a/mindspore/ccsrc/fl/server/common.h +++ b/mindspore/ccsrc/fl/server/common.h @@ -61,7 +61,7 @@ struct RoundConfig { struct CipherConfig { float share_secrets_ratio = 1.0; - float get_model_ratio = 1.0; + uint64_t cipher_time_window = 300000; size_t reconstruct_secrets_threshhold = 0; }; diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc index 7609874b609..09d1d6da024 100644 --- 
a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc @@ -243,11 +243,7 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P auto update_model_threshold = metadata_[name].mutable_update_model_threshold(); *update_model_threshold = meta.update_model_threshold(); } else if (meta.has_prime()) { - auto prime_list = metadata_[name].mutable_prime_list(); - auto &prime_id = meta.prime().prime(); - if (prime_list->prime_size() == 0) { - prime_list->add_prime(prime_id); - } + metadata_[name] = meta; } else if (meta.has_pair_client_keys()) { auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys(); auto &fl_id = meta.pair_client_keys().fl_id(); diff --git a/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc index be4852705d4..80884e76e16 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc @@ -109,7 +109,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient bool ClientListKernel::Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) { std::shared_ptr fbb = std::make_shared(); - bool response = false; + // bool response = false; size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration; @@ -139,7 +139,8 @@ bool ClientListKernel::Launch(const std::vector &inputs, const std:: "GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list, std::to_string(CURRENT_TIME_MILLI.count()), iter_num); } else { - response = DealClient(iter_num, get_clients_req, fbb); + // response = DealClient(iter_num, get_clients_req, 
fbb); + DealClient(iter_num, get_clients_req, fbb); } } } @@ -148,7 +149,7 @@ bool ClientListKernel::Launch(const std::vector &inputs, const std:: clock_t end_time = clock(); double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); MS_LOG(INFO) << "client_list_kernel success time is : " << duration; - return response; + return true; } // namespace ps bool ClientListKernel::Reset() { diff --git a/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc index 8b75bee217a..623a4ab90cb 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc @@ -85,7 +85,10 @@ bool ExchangeKeysKernel::Launch(const std::vector &inputs, const std clock_t end_time = clock(); double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration; - return response; + if (!response) { + MS_LOG(INFO) << "ExchangeKeysKernel response is false."; + } + return true; } bool ExchangeKeysKernel::Reset() { diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc index 607645cd7f1..6fc0422632c 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc @@ -82,7 +82,10 @@ bool GetKeysKernel::Launch(const std::vector &inputs, const std::vec clock_t end_time = clock(); double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration; - return response; + if (!response) { + MS_LOG(INFO) << "GetKeysKernel response is false."; + } + return true; } bool GetKeysKernel::Reset() { diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc index 
4f34249703e..4b1f5d5a4b3 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc @@ -48,7 +48,7 @@ bool GetSecretsKernel::Launch(const std::vector &inputs, const std:: MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count()); size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); - MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is " + MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is " << total_duration; clock_t start_time = clock(); @@ -84,7 +84,10 @@ bool GetSecretsKernel::Launch(const std::vector &inputs, const std:: clock_t end_time = clock(); double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration; - return response; + if (!response) { + MS_LOG(INFO) << "GetSecretsKernel response is false."; + } + return true; } bool GetSecretsKernel::Reset() { diff --git a/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc index 57ea9ac33bb..d36d7b4be28 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc @@ -76,7 +76,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector &inputs, con "Current amount for ReconstructSecretsKernel is enough.", iter_num, std::to_string(CURRENT_TIME_MILLI.count())); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return false; + return true; } void *req_data = inputs[0]->addr; @@ -94,7 +94,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector &inputs, con "update_model_client_num is zero.", iter_num, 
std::to_string(CURRENT_TIME_MILLI.count())); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return false; + return true; } const PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList); @@ -110,7 +110,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector &inputs, con "ReconstructSecretsKernel : client list is not ready", iter_num, std::to_string(CURRENT_TIME_MILLI.count())); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - return response; + return true; } for (int i = 0; i < client_list_pb.fl_id_size(); ++i) { client_list.push_back(client_list_pb.fl_id(i)); @@ -128,7 +128,10 @@ bool ReconstructSecretsKernel::Launch(const std::vector &inputs, con clock_t end_time = clock(); double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); MS_LOG(INFO) << "reconstruct_secrets_kernel success time is : " << duration; - return response; + if (!response) { + MS_LOG(INFO) << "reconstruct_secrets_kernel response is false."; + } + return true; } void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr &message) { diff --git a/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc index 4f183f106a4..9df01405783 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc @@ -85,7 +85,10 @@ bool ShareSecretsKernel::Launch(const std::vector &inputs, const std clock_t end_time = clock(); double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration; - return response; + if (!response) { + MS_LOG(INFO) << "share_secrets_kernel response is false."; + } + return true; } bool ShareSecretsKernel::Reset() { diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc 
b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc index 549e3829af3..c0e29b49468 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc @@ -204,7 +204,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, float dp_eps = param->dp_eps; float dp_delta = param->dp_delta; float dp_norm_clip = param->dp_norm_clip; - auto encrypt_type = fbb->CreateString(param->encrypt_type); + auto encrypt_type = fbb->CreateString(PSContext::instance()->encrypt_type()); auto cipher_public_params = schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type); diff --git a/mindspore/ccsrc/fl/server/server.cc b/mindspore/ccsrc/fl/server/server.cc index cf3b7f52440..ca4d9f984a1 100644 --- a/mindspore/ccsrc/fl/server/server.cc +++ b/mindspore/ccsrc/fl/server/server.cc @@ -84,7 +84,11 @@ void Server::Run() { RegisterCommCallbacks(); StartCommunicator(); InitExecutor(); - InitCipher(); + std::string encrypt_type = PSContext::instance()->encrypt_type(); + if (encrypt_type != kNotEncryptType) { + InitCipher(); + MS_LOG(INFO) << "Parameters for secure aggregation have been initiated."; + } RegisterRoundKernel(); MS_LOG(INFO) << "Server started successfully."; safemode_ = false; @@ -182,40 +186,46 @@ void Server::InitIteration() { iteration_->AddRound(round); } - cipher_initial_client_cnt_ = rounds_config_[0].threshold_count; - cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0; - cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio; - cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count; - cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count; - cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold; - - MS_LOG(INFO) << "Initializing cipher:"; - MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_ - 
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_ - << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_; - MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_ - << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_ - << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_; - #ifdef ENABLE_ARMOUR - std::shared_ptr exchange_keys_round = - std::make_shared("exchangeKeys", false, 3000, true, cipher_exchange_secrets_cnt_); - iteration_->AddRound(exchange_keys_round); - std::shared_ptr get_keys_round = - std::make_shared("getKeys", false, 3000, true, cipher_exchange_secrets_cnt_); - iteration_->AddRound(get_keys_round); - std::shared_ptr share_secrets_round = - std::make_shared("shareSecrets", false, 3000, true, cipher_share_secrets_cnt_); - iteration_->AddRound(share_secrets_round); - std::shared_ptr get_secrets_round = - std::make_shared("getSecrets", false, 3000, true, cipher_share_secrets_cnt_); - iteration_->AddRound(get_secrets_round); - std::shared_ptr get_clientlist_round = - std::make_shared("getClientList", false, 3000, true, cipher_get_clientlist_cnt_); - iteration_->AddRound(get_clientlist_round); - std::shared_ptr reconstruct_secrets_round = - std::make_shared("reconstructSecrets", false, 3000, true, cipher_reconstruct_secrets_up_cnt_); - iteration_->AddRound(reconstruct_secrets_round); + std::string encrypt_type = PSContext::instance()->encrypt_type(); + if (encrypt_type == kPWEncryptType) { + cipher_initial_client_cnt_ = rounds_config_[0].threshold_count; + cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0; + cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio; + cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count; + cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count; + cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold; + 
cipher_time_window_ = cipher_config_.cipher_time_window; + + MS_LOG(INFO) << "Initializing cipher:"; + MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_ + << " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_ + << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_; + MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_ + << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_ + << " cipher_time_window_: " << cipher_time_window_ + << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_; + + std::shared_ptr exchange_keys_round = + std::make_shared("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_); + iteration_->AddRound(exchange_keys_round); + std::shared_ptr get_keys_round = + std::make_shared("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_); + iteration_->AddRound(get_keys_round); + std::shared_ptr share_secrets_round = + std::make_shared("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_); + iteration_->AddRound(share_secrets_round); + std::shared_ptr get_secrets_round = + std::make_shared("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_); + iteration_->AddRound(get_secrets_round); + std::shared_ptr get_clientlist_round = + std::make_shared("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_); + iteration_->AddRound(get_clientlist_round); + std::shared_ptr reconstruct_secrets_round = std::make_shared( + "reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_); + iteration_->AddRound(reconstruct_secrets_round); + MS_LOG(INFO) << "Cipher rounds has been added."; + } #endif // 2.Initialize all the rounds. 
diff --git a/mindspore/ccsrc/fl/server/server.h b/mindspore/ccsrc/fl/server/server.h index 1b1dfea5fb8..cc8da047efc 100644 --- a/mindspore/ccsrc/fl/server/server.h +++ b/mindspore/ccsrc/fl/server/server.h @@ -170,6 +170,7 @@ class Server { size_t cipher_get_clientlist_cnt_; size_t cipher_reconstruct_secrets_up_cnt_; size_t cipher_reconstruct_secrets_down_cnt_; + uint64_t cipher_time_window_; float percent_for_update_model_; float percent_for_get_model_; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 351f9a7d830..3c818d10aac 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -673,10 +673,10 @@ bool StartServerAction(const ResourcePtr &res) { {"pushWeight", false, 3000, true, server_num, true}}; float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio(); - float get_model_ratio = ps::PSContext::instance()->get_model_ratio(); + uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window(); size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold(); - ps::server::CipherConfig cipher_config = {share_secrets_ratio, get_model_ratio, reconstruct_secrets_threshhold}; + ps::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold}; size_t executor_threshold = 0; if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 1622adfc36c..c9bbf7c0f1f 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -371,8 +371,8 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_share_secrets_ratio", &PSContext::set_share_secrets_ratio, "Set threshold count ratio for share secrets round.") .def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.") -
.def("set_get_model_ratio", &PSContext::set_get_model_ratio, "Set threshold count ratio for get model round.") - .def("get_model_ratio", &PSContext::get_model_ratio, "Get threshold count ratio for get model round.") + .def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.") + .def("cipher_time_window", &PSContext::cipher_time_window, "Get time window for each cipher round.") .def("set_reconstruct_secrets_threshhold", &PSContext::set_reconstruct_secrets_threshhold, "Set threshold count for reconstruct secrets round.") .def("reconstruct_secrets_threshhold", &PSContext::reconstruct_secrets_threshhold, diff --git a/mindspore/ccsrc/ps/core/protos/fl.proto b/mindspore/ccsrc/ps/core/protos/fl.proto index b338e067e08..70dbe6efc54 100644 --- a/mindspore/ccsrc/ps/core/protos/fl.proto +++ b/mindspore/ccsrc/ps/core/protos/fl.proto @@ -105,10 +105,6 @@ message ClientNoises { OneClientNoises one_client_noises = 1; } -message PrimeList { - repeated bytes prime = 1; -} - message PairClientKeys { string fl_id = 1; KeysPb client_keys = 2; @@ -156,7 +152,6 @@ message PBMetadata { ClientNoises client_noises = 11; Prime prime = 12; - PrimeList prime_list = 13; } } diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 7b84450af38..18b81df612c 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -339,9 +339,12 @@ void PSContext::set_share_secrets_ratio(float share_secrets_ratio) { share_secre float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; } -void PSContext::set_get_model_ratio(float get_model_ratio) { get_model_ratio_ = get_model_ratio; } +void PSContext::set_cipher_time_window(uint64_t cipher_time_window) { + // The parameter is unsigned, so it can never be negative and needs no range check. + cipher_time_window_ = cipher_time_window; +} -float PSContext::get_model_ratio() const { return get_model_ratio_; } +uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; } void
PSContext::set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold) { reconstruct_secrets_threshhold_ = reconstruct_secrets_threshhold; diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 7614def1a96..124c6bf6ee2 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -131,8 +131,8 @@ class PSContext { void set_share_secrets_ratio(float share_secrets_ratio); float share_secrets_ratio() const; - void set_get_model_ratio(float get_model_ratio); - float get_model_ratio() const; + void set_cipher_time_window(uint64_t cipher_time_window); + uint64_t cipher_time_window() const; void set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold); uint64_t reconstruct_secrets_threshhold() const; @@ -201,7 +201,7 @@ class PSContext { update_model_ratio_(1.0), update_model_time_window_(3000), share_secrets_ratio_(1.0), - get_model_ratio_(1.0), + cipher_time_window_(300000), reconstruct_secrets_threshhold_(2000), fl_iteration_num_(20), client_epoch_num_(25), @@ -260,8 +260,8 @@ class PSContext { // Share model threshold is a certain ratio of share secrets threshold which is set as share_secrets_ratio_. float share_secrets_ratio_; - // Get model threshold is a certain ratio of get model threshold which is set as get_model_ratio_. - float get_model_ratio_; + // The time window of each cipher round in millisecond. + uint64_t cipher_time_window_; // The threshold count of reconstruct secrets round. Used in federated learning for now. uint64_t reconstruct_secrets_threshhold_; diff --git a/mindspore/context.py b/mindspore/context.py index 833f6cf3311..0ed54a3652e 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -840,7 +840,7 @@ def set_fl_context(**kwargs): start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000. share_secrets_ratio (float): The ratio for computing the threshold count of share secrets. Default: 1.0.
update_model_ratio (float): The ratio for computing the threshold count of updateModel. Default: 1.0. - get_model_ratio (float): The ratio for computing the threshold count of get model. Default: 1.0. + cipher_time_window (int): The time window duration for each cipher round in millisecond. Default: 300000. reconstruct_secrets_threshold (int): The threshold count of reconstruct threshold. Default: 0. update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000. fl_name (string): The federated learning job name. Default: ''. diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py index cfd119b634c..6a9f7897a59 100644 --- a/mindspore/parallel/_ps_context.py +++ b/mindspore/parallel/_ps_context.py @@ -58,7 +58,7 @@ _set_ps_context_func_map = { "update_model_ratio": ps_context().set_update_model_ratio, "update_model_time_window": ps_context().set_update_model_time_window, "share_secrets_ratio": ps_context().set_share_secrets_ratio, - "get_model_ratio": ps_context().set_get_model_ratio, + "cipher_time_window": ps_context().set_cipher_time_window, "reconstruct_secrets_threshhold": ps_context().set_reconstruct_secrets_threshhold, "fl_name": ps_context().set_fl_name, "fl_iteration_num": ps_context().set_fl_iteration_num, @@ -91,7 +91,7 @@ _get_ps_context_func_map = { "update_model_ratio": ps_context().update_model_ratio, "update_model_time_window": ps_context().update_model_time_window, "share_secrets_ratio": ps_context().share_secrets_ratio, - "get_model_ratio": ps_context().get_model_ratio, + "cipher_time_window": ps_context().cipher_time_window, "reconstruct_secrets_threshhold": ps_context().reconstruct_secrets_threshhold, "fl_name": ps_context().fl_name, "fl_iteration_num": ps_context().fl_iteration_num, diff --git a/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py b/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py index c8872692f83..9ea54d3030d 100644 ---
a/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py +++ b/tests/st/fl/hybrid_lenet/run_hybrid_train_server.py @@ -35,6 +35,15 @@ parser.add_argument("--client_batch_size", type=int, default=32) parser.add_argument("--client_learning_rate", type=float, default=0.1) parser.add_argument("--local_server_num", type=int, default=-1) parser.add_argument("--config_file_path", type=str, default="") +parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") +# parameters for encrypt_type='DP_ENCRYPT' +parser.add_argument("--dp_eps", type=float, default=50.0) +parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num +parser.add_argument("--dp_norm_clip", type=float, default=1.0) +# parameters for encrypt_type='PW_ENCRYPT' +parser.add_argument("--share_secrets_ratio", type=float, default=1.0) +parser.add_argument("--cipher_time_window", type=int, default=300000) +parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3) args, _ = parser.parse_known_args() device_target = args.device_target @@ -55,6 +64,13 @@ client_batch_size = args.client_batch_size client_learning_rate = args.client_learning_rate local_server_num = args.local_server_num config_file_path = args.config_file_path +encrypt_type = args.encrypt_type +share_secrets_ratio = args.share_secrets_ratio +cipher_time_window = args.cipher_time_window +reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold +dp_eps = args.dp_eps +dp_delta = args.dp_delta +dp_norm_clip = args.dp_norm_clip if local_server_num == -1: local_server_num = server_num @@ -85,6 +101,13 @@ for i in range(local_server_num): cmd_server += " --client_epoch_num=" + str(client_epoch_num) cmd_server += " --client_batch_size=" + str(client_batch_size) cmd_server += " --client_learning_rate=" + str(client_learning_rate) + cmd_server += " --encrypt_type=" + str(encrypt_type) + cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) + cmd_server += " --cipher_time_window=" + 
str(cipher_time_window) + cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold) + cmd_server += " --dp_eps=" + str(dp_eps) + cmd_server += " --dp_delta=" + str(dp_delta) + cmd_server += " --dp_norm_clip=" + str(dp_norm_clip) cmd_server += " > server.log 2>&1 &" import time diff --git a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py index f7bea5d1527..188916daf3d 100644 --- a/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py +++ b/tests/st/fl/hybrid_lenet/test_hybrid_train_lenet.py @@ -45,6 +45,15 @@ parser.add_argument("--client_learning_rate", type=float, default=0.1) parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) parser.add_argument("--scheduler_manage_port", type=int, default=11202) parser.add_argument("--config_file_path", type=str, default="") +parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") +# parameters for encrypt_type='DP_ENCRYPT' +parser.add_argument("--dp_eps", type=float, default=50.0) +parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num +parser.add_argument("--dp_norm_clip", type=float, default=1.0) +# parameters for encrypt_type='PW_ENCRYPT' +parser.add_argument("--share_secrets_ratio", type=float, default=1.0) +parser.add_argument("--cipher_time_window", type=int, default=300000) +parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3) args, _ = parser.parse_known_args() device_target = args.device_target @@ -67,6 +76,13 @@ client_learning_rate = args.client_learning_rate worker_step_num_per_iteration = args.worker_step_num_per_iteration scheduler_manage_port = args.scheduler_manage_port config_file_path = args.config_file_path +encrypt_type = args.encrypt_type +share_secrets_ratio = args.share_secrets_ratio +cipher_time_window = args.cipher_time_window +reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold +dp_eps = args.dp_eps +dp_delta = 
args.dp_delta +dp_norm_clip = args.dp_norm_clip ctx = { "enable_fl": True, @@ -88,7 +104,14 @@ ctx = { "client_learning_rate": client_learning_rate, "worker_step_num_per_iteration": worker_step_num_per_iteration, "scheduler_manage_port": scheduler_manage_port, - "config_file_path": config_file_path + "config_file_path": config_file_path, + "share_secrets_ratio": share_secrets_ratio, + "cipher_time_window": cipher_time_window, + "reconstruct_secrets_threshhold": reconstruct_secrets_threshhold, + "dp_eps": dp_eps, + "dp_delta": dp_delta, + "dp_norm_clip": dp_norm_clip, + "encrypt_type": encrypt_type } context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False) diff --git a/tests/st/fl/mobile/run_mobile_server.py b/tests/st/fl/mobile/run_mobile_server.py index 2b5bdbce6b5..26fd7f0cd5b 100644 --- a/tests/st/fl/mobile/run_mobile_server.py +++ b/tests/st/fl/mobile/run_mobile_server.py @@ -28,9 +28,6 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1) parser.add_argument("--start_fl_job_time_window", type=int, default=3000) parser.add_argument("--update_model_ratio", type=float, default=1.0) parser.add_argument("--update_model_time_window", type=int, default=3000) -parser.add_argument("--share_secrets_ratio", type=float, default=1.0) -parser.add_argument("--get_model_ratio", type=float, default=1.0) -parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0) parser.add_argument("--fl_name", type=str, default="Lenet") parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) @@ -43,6 +40,10 @@ parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num parser.add_argument("--dp_norm_clip", type=float, default=1.0) +# parameters for encrypt_type='PW_ENCRYPT' 
+parser.add_argument("--share_secrets_ratio", type=float, default=1.0) +parser.add_argument("--cipher_time_window", type=int, default=300000) +parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3) if __name__ == "__main__": args, _ = parser.parse_known_args() @@ -58,7 +59,7 @@ if __name__ == "__main__": update_model_ratio = args.update_model_ratio update_model_time_window = args.update_model_time_window share_secrets_ratio = args.share_secrets_ratio - get_model_ratio = args.get_model_ratio + cipher_time_window = args.cipher_time_window reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold fl_name = args.fl_name fl_iteration_num = args.fl_iteration_num @@ -96,7 +97,7 @@ if __name__ == "__main__": cmd_server += " --update_model_ratio=" + str(update_model_ratio) cmd_server += " --update_model_time_window=" + str(update_model_time_window) cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) - cmd_server += " --get_model_ratio=" + str(get_model_ratio) + cmd_server += " --cipher_time_window=" + str(cipher_time_window) cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold) cmd_server += " --fl_name=" + fl_name cmd_server += " --fl_iteration_num=" + str(fl_iteration_num) diff --git a/tests/st/fl/mobile/test_mobile_lenet.py b/tests/st/fl/mobile/test_mobile_lenet.py index 345f52dd13b..a6fded1ccec 100644 --- a/tests/st/fl/mobile/test_mobile_lenet.py +++ b/tests/st/fl/mobile/test_mobile_lenet.py @@ -36,9 +36,6 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1) parser.add_argument("--start_fl_job_time_window", type=int, default=3000) parser.add_argument("--update_model_ratio", type=float, default=1.0) parser.add_argument("--update_model_time_window", type=int, default=3000) -parser.add_argument("--share_secrets_ratio", type=float, default=1.0) -parser.add_argument("--get_model_ratio", type=float, default=1.0) -parser.add_argument("--reconstruct_secrets_threshhold", 
type=int, default=0) parser.add_argument("--fl_name", type=str, default="Lenet") parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) @@ -51,6 +48,10 @@ parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") parser.add_argument("--dp_eps", type=float, default=50.0) parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num parser.add_argument("--dp_norm_clip", type=float, default=1.0) +# parameters for encrypt_type='PW_ENCRYPT' +parser.add_argument("--share_secrets_ratio", type=float, default=1.0) +parser.add_argument("--cipher_time_window", type=int, default=300000) +parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3) args, _ = parser.parse_known_args() device_target = args.device_target @@ -66,7 +67,7 @@ start_fl_job_time_window = args.start_fl_job_time_window update_model_ratio = args.update_model_ratio update_model_time_window = args.update_model_time_window share_secrets_ratio = args.share_secrets_ratio -get_model_ratio = args.get_model_ratio +cipher_time_window = args.cipher_time_window reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold fl_name = args.fl_name fl_iteration_num = args.fl_iteration_num @@ -94,7 +95,7 @@ ctx = { "update_model_ratio": update_model_ratio, "update_model_time_window": update_model_time_window, "share_secrets_ratio": share_secrets_ratio, - "get_model_ratio": get_model_ratio, + "cipher_time_window": cipher_time_window, "reconstruct_secrets_threshhold": reconstruct_secrets_threshhold, "fl_name": fl_name, "fl_iteration_num": fl_iteration_num,