Merge pull request !31197 from tan-wei-cheng-3260/develop-twc-master
This commit is contained in:
i-robot 2022-03-14 06:45:50 +00:00 committed by Gitee
commit a6454e02e4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 32 additions and 3 deletions

View File

@ -75,6 +75,7 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
if (param.encrypt_type == mindspore::ps::kPWEncryptType) { if (param.encrypt_type == mindspore::ps::kPWEncryptType) {
cipher_meta_storage_.RegisterClass(); cipher_meta_storage_.RegisterClass();
const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN); const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN);
new_prime_ = new_prime;
cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime); cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime);
if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) { if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) {
MS_LOG(ERROR) << "Cipher Param Update is invalid."; MS_LOG(ERROR) << "Cipher Param Update is invalid.";
@ -102,6 +103,19 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
return true; return true;
} }
bool CipherInit::ReInitForScaling() {
if (ps::PSContext::instance()->encrypt_type() == mindspore::ps::kPWEncryptType) {
cipher_meta_storage_.RegisterClass();
cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime_);
if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) {
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
return false;
}
}
MS_LOG(INFO) << "CipherInit reinit for scaling success.";
return true;
}
bool CipherInit::Check_Parames() { bool CipherInit::Check_Parames() {
MS_LOG(INFO) << "Check cipher params:"; MS_LOG(INFO) << "Check cipher params:";
if (featuremap_ < 1) { if (featuremap_ < 1) {

View File

@ -43,6 +43,8 @@ class CipherInit {
size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt, size_t cipher_get_list_sign_cnt, size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt, size_t cipher_get_list_sign_cnt,
size_t cipher_clients_threshold_for_reconstruct); size_t cipher_clients_threshold_for_reconstruct);
bool ReInitForScaling();
// Get public params. which is given to start fl job thread. // Get public params. which is given to start fl job thread.
CipherPublicPara *GetPublicParams() { return &publicparam_; } CipherPublicPara *GetPublicParams() { return &publicparam_; }
@ -72,6 +74,8 @@ class CipherInit {
// Check whether the parameters are valid. // Check whether the parameters are valid.
bool Check_Parames(); bool Check_Parames();
std::string new_prime_;
}; };
} // namespace armour } // namespace armour
} // namespace mindspore } // namespace mindspore

View File

@ -170,10 +170,10 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
return false; return false;
} }
MS_LOG(ERROR) << "recombined shares"; MS_LOG(INFO) << "recombined shares";
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) { for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
MS_LOG(ERROR) << "fl_id: " << iter->first; MS_LOG(INFO) << "fl_id: " << iter->first;
MS_LOG(ERROR) << "share size: " << iter->second.size(); MS_LOG(INFO) << "share size: " << iter->second.size();
} }
std::vector<Share *> shares_tmp; std::vector<Share *> shares_tmp;
if (!MallocShares(&shares_tmp, (SizeToInt)(cipher_init_->secrets_minnums_))) { if (!MallocShares(&shares_tmp, (SizeToInt)(cipher_init_->secrets_minnums_))) {

View File

@ -20,6 +20,7 @@
#include <csignal> #include <csignal>
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
#include "fl/armour/secure_protocol/secret_sharing.h" #include "fl/armour/secure_protocol/secret_sharing.h"
#include "fl/armour/cipher/cipher_init.h"
#endif #endif
#include "fl/server/round.h" #include "fl/server/round.h"
#include "fl/server/model_store.h" #include "fl/server/model_store.h"
@ -537,6 +538,11 @@ void Server::ProcessAfterScalingOut() {
if (!Executor::GetInstance().ReInitForScaling()) { if (!Executor::GetInstance().ReInitForScaling()) {
MS_LOG(WARNING) << "Executor reinitializing failed."; MS_LOG(WARNING) << "Executor reinitializing failed.";
} }
#ifdef ENABLE_ARMOUR
if (!armour::CipherInit::GetInstance().ReInitForScaling()) {
MS_LOG(WARNING) << "CipherInit reinitializing failed.";
}
#endif
std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
safemode_ = false; safemode_ = false;
} }
@ -565,6 +571,11 @@ void Server::ProcessAfterScalingIn() {
if (!Executor::GetInstance().ReInitForScaling()) { if (!Executor::GetInstance().ReInitForScaling()) {
MS_LOG(WARNING) << "Executor reinitializing failed."; MS_LOG(WARNING) << "Executor reinitializing failed.";
} }
#ifdef ENABLE_ARMOUR
if (!armour::CipherInit::GetInstance().ReInitForScaling()) {
MS_LOG(WARNING) << "CipherInit reinitializing failed.";
}
#endif
std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking)); std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
safemode_ = false; safemode_ = false;
} }