!31197 fix I4VPZ5
Merge pull request !31197 from tan-wei-cheng-3260/develop-twc-master
This commit is contained in:
commit
a6454e02e4
|
@ -75,6 +75,7 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
||||||
if (param.encrypt_type == mindspore::ps::kPWEncryptType) {
|
if (param.encrypt_type == mindspore::ps::kPWEncryptType) {
|
||||||
cipher_meta_storage_.RegisterClass();
|
cipher_meta_storage_.RegisterClass();
|
||||||
const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN);
|
const std::string new_prime(reinterpret_cast<const char *>(param.prime), PRIME_MAX_LEN);
|
||||||
|
new_prime_ = new_prime;
|
||||||
cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime);
|
cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime);
|
||||||
if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) {
|
if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) {
|
||||||
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
|
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
|
||||||
|
@ -102,6 +103,19 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool CipherInit::ReInitForScaling() {
|
||||||
|
if (ps::PSContext::instance()->encrypt_type() == mindspore::ps::kPWEncryptType) {
|
||||||
|
cipher_meta_storage_.RegisterClass();
|
||||||
|
cipher_meta_storage_.RegisterPrime(fl::server::kCtxCipherPrimer, new_prime_);
|
||||||
|
if (!cipher_meta_storage_.GetPrimeFromServer(fl::server::kCtxCipherPrimer, publicparam_.prime)) {
|
||||||
|
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "CipherInit reinit for scaling success.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool CipherInit::Check_Parames() {
|
bool CipherInit::Check_Parames() {
|
||||||
MS_LOG(INFO) << "Check cipher params:";
|
MS_LOG(INFO) << "Check cipher params:";
|
||||||
if (featuremap_ < 1) {
|
if (featuremap_ < 1) {
|
||||||
|
|
|
@ -43,6 +43,8 @@ class CipherInit {
|
||||||
size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt, size_t cipher_get_list_sign_cnt,
|
size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt, size_t cipher_get_list_sign_cnt,
|
||||||
size_t cipher_clients_threshold_for_reconstruct);
|
size_t cipher_clients_threshold_for_reconstruct);
|
||||||
|
|
||||||
|
bool ReInitForScaling();
|
||||||
|
|
||||||
// Get public params. which is given to start fl job thread.
|
// Get public params. which is given to start fl job thread.
|
||||||
CipherPublicPara *GetPublicParams() { return &publicparam_; }
|
CipherPublicPara *GetPublicParams() { return &publicparam_; }
|
||||||
|
|
||||||
|
@ -72,6 +74,8 @@ class CipherInit {
|
||||||
|
|
||||||
// Check whether the parameters are valid.
|
// Check whether the parameters are valid.
|
||||||
bool Check_Parames();
|
bool Check_Parames();
|
||||||
|
|
||||||
|
std::string new_prime_;
|
||||||
};
|
};
|
||||||
} // namespace armour
|
} // namespace armour
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -170,10 +170,10 @@ bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &cl
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(ERROR) << "recombined shares";
|
MS_LOG(INFO) << "recombined shares";
|
||||||
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
|
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
|
||||||
MS_LOG(ERROR) << "fl_id: " << iter->first;
|
MS_LOG(INFO) << "fl_id: " << iter->first;
|
||||||
MS_LOG(ERROR) << "share size: " << iter->second.size();
|
MS_LOG(INFO) << "share size: " << iter->second.size();
|
||||||
}
|
}
|
||||||
std::vector<Share *> shares_tmp;
|
std::vector<Share *> shares_tmp;
|
||||||
if (!MallocShares(&shares_tmp, (SizeToInt)(cipher_init_->secrets_minnums_))) {
|
if (!MallocShares(&shares_tmp, (SizeToInt)(cipher_init_->secrets_minnums_))) {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <csignal>
|
#include <csignal>
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
#include "fl/armour/secure_protocol/secret_sharing.h"
|
#include "fl/armour/secure_protocol/secret_sharing.h"
|
||||||
|
#include "fl/armour/cipher/cipher_init.h"
|
||||||
#endif
|
#endif
|
||||||
#include "fl/server/round.h"
|
#include "fl/server/round.h"
|
||||||
#include "fl/server/model_store.h"
|
#include "fl/server/model_store.h"
|
||||||
|
@ -537,6 +538,11 @@ void Server::ProcessAfterScalingOut() {
|
||||||
if (!Executor::GetInstance().ReInitForScaling()) {
|
if (!Executor::GetInstance().ReInitForScaling()) {
|
||||||
MS_LOG(WARNING) << "Executor reinitializing failed.";
|
MS_LOG(WARNING) << "Executor reinitializing failed.";
|
||||||
}
|
}
|
||||||
|
#ifdef ENABLE_ARMOUR
|
||||||
|
if (!armour::CipherInit::GetInstance().ReInitForScaling()) {
|
||||||
|
MS_LOG(WARNING) << "CipherInit reinitializing failed.";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
|
std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
|
||||||
safemode_ = false;
|
safemode_ = false;
|
||||||
}
|
}
|
||||||
|
@ -565,6 +571,11 @@ void Server::ProcessAfterScalingIn() {
|
||||||
if (!Executor::GetInstance().ReInitForScaling()) {
|
if (!Executor::GetInstance().ReInitForScaling()) {
|
||||||
MS_LOG(WARNING) << "Executor reinitializing failed.";
|
MS_LOG(WARNING) << "Executor reinitializing failed.";
|
||||||
}
|
}
|
||||||
|
#ifdef ENABLE_ARMOUR
|
||||||
|
if (!armour::CipherInit::GetInstance().ReInitForScaling()) {
|
||||||
|
MS_LOG(WARNING) << "CipherInit reinitializing failed.";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
|
std::this_thread::sleep_for(std::chrono::milliseconds(kServerSleepTimeForNetworking));
|
||||||
safemode_ = false;
|
safemode_ = false;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue