forked from mindspore-Ecosystem/mindspore
Add secure parameters for mindspore federated learning.
fix prime initialization and build key bug fix reconstruct access bug and kernels' retcode. Fix init issue Add fl secure parameter cipher_time_window
This commit is contained in:
parent
e372634d16
commit
d2b42fd12a
|
@ -15,8 +15,10 @@
|
|||
*/
|
||||
|
||||
#include "fl/armour/cipher/cipher_init.h"
|
||||
#include "fl/server/common.h"
|
||||
|
||||
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/model_store.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
@ -43,8 +45,7 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
|||
publicparam_.t = param.t;
|
||||
secrets_minnums_ = param.t;
|
||||
client_num_need_ = cipher_initial_client_cnt;
|
||||
featuremap_ = 1000; // todo: wait for other code
|
||||
// merge.ps::server::DistributedMetadataStore::GetInstance().model_size() / sizeof(float);
|
||||
featuremap_ = ps::server::ModelStore::GetInstance().model_size() / sizeof(float);
|
||||
share_clients_num_need_ = cipher_share_secrets_cnt;
|
||||
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
|
||||
get_model_num_need_ = cipher_get_clientlist_cnt;
|
||||
|
|
|
@ -38,14 +38,14 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
|
|||
std::string fl_id = get_exchange_keys_req->fl_id()->str();
|
||||
|
||||
if (find(clients.begin(), clients.end(), fl_id) == clients.end()) {
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
|
||||
MS_LOG(INFO) << "The fl_id is not in clients.";
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
|
||||
return false;
|
||||
}
|
||||
if (cur_clients_num < cipher_init_->client_num_need_) {
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
|
||||
MS_LOG(INFO) << "The server is not ready yet: cur_clients_num < client_num_need";
|
||||
MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << "cur_clients_num : " << cipher_init_->client_num_need_;
|
||||
MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << ", cur_clients_num : " << cipher_init_->client_num_need_;
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
|
|||
}
|
||||
|
||||
bool flag =
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, cur_iterator, next_req_time, true);
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
|
||||
return flag;
|
||||
} // namespace armour
|
||||
|
||||
|
@ -95,17 +95,17 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
|
|||
MS_LOG(INFO) << "client_num_need_ " << cipher_init_->client_num_need_ << ". cur_clients_num " << cur_clients_num;
|
||||
std::string fl_id = exchange_keys_req->fl_id()->str();
|
||||
if (cur_clients_num >= cipher_init_->client_num_need_) { // the client num is enough, return false.
|
||||
MS_LOG(ERROR) << "The server has received enough requests and refuse this request.";
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime,
|
||||
"The server has received enough requests and refuse this request.", next_req_time,
|
||||
cur_iterator);
|
||||
MS_LOG(ERROR) << "The server has received enough requests and refuse this request.";
|
||||
return false;
|
||||
}
|
||||
if (record_public_keys.find(fl_id) != record_public_keys.end()) { // the client already exists, return false.
|
||||
MS_LOG(INFO) << "The server has received the request, please do not request again.";
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
|
||||
"The server has received the request, please do not request again.", next_req_time,
|
||||
cur_iterator);
|
||||
MS_LOG(INFO) << "The server has received the request, please do not request again.";
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -135,9 +135,9 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
|
|||
bool retcode_client =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxExChangeKeysClientList, fl_id);
|
||||
if (retcode_key && retcode_client) {
|
||||
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
|
||||
"Success, but the server is not ready yet.", next_req_time, cur_iterator);
|
||||
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
|
||||
return true;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "update key or client failed";
|
||||
|
@ -164,7 +164,6 @@ void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exc
|
|||
|
||||
bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||
const int iteration, const std::string &next_req_time, bool is_good) {
|
||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||
bool flag = true;
|
||||
if (is_good) {
|
||||
// convert client keys to standard keys list.
|
||||
|
@ -176,6 +175,14 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
|
|||
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
|
||||
<< "clients num: " << cipher_init_->client_num_need_;
|
||||
flag = false;
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||
rsp_buider.add_retcode(retcode);
|
||||
rsp_buider.add_iteration(iteration);
|
||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||
auto rsp_get_keys = rsp_buider.Finish();
|
||||
|
||||
fbb->Finish(rsp_get_keys);
|
||||
} else {
|
||||
for (auto iter = record_public_keys.begin(); iter != record_public_keys.end(); ++iter) {
|
||||
// read (fl_id, c_pk, s_pk) from the map: record_public_keys_
|
||||
|
@ -190,18 +197,26 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
|
|||
public_keys_list.push_back(cur_public_key);
|
||||
}
|
||||
auto remote_publickeys = fbb->CreateVector(public_keys_list);
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||
rsp_buider.add_retcode(retcode);
|
||||
rsp_buider.add_iteration(iteration);
|
||||
rsp_buider.add_remote_publickeys(remote_publickeys);
|
||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||
auto rsp_get_keys = rsp_buider.Finish();
|
||||
fbb->Finish(rsp_get_keys);
|
||||
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
|
||||
flag = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||
rsp_buider.add_retcode(retcode);
|
||||
rsp_buider.add_iteration(iteration);
|
||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||
auto rsp_get_keys = rsp_buider.Finish();
|
||||
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
rsp_buider.add_retcode(retcode);
|
||||
rsp_buider.add_iteration(iteration);
|
||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||
auto rsp_get_keys = rsp_buider.Finish();
|
||||
fbb->Finish(rsp_get_keys);
|
||||
fbb->Finish(rsp_get_keys);
|
||||
}
|
||||
return flag;
|
||||
}
|
||||
|
||||
|
|
|
@ -82,16 +82,19 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
|
|||
return true;
|
||||
}
|
||||
|
||||
bool CipherMetaStorage::GetPrimeFromServer(const char *list_name, unsigned char *prime) {
|
||||
const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||
auto &prime_list_pb = prime_pb_out.prime_list();
|
||||
if (prime_list_pb.prime_size() > 0 && prime_list_pb.prime(0).size() >= PRIME_MAX_LEN) {
|
||||
for (int i = 0; i < PRIME_MAX_LEN; i++) {
|
||||
prime[i] = static_cast<unsigned char>(prime_list_pb.prime(0)[i]);
|
||||
}
|
||||
bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) {
|
||||
const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
|
||||
ps::Prime prime_pb(prime_pb_out.prime());
|
||||
std::string str = *(prime_pb.mutable_prime());
|
||||
MS_LOG(INFO) << "get prime from metastorage :" << str;
|
||||
|
||||
if (str.size() != PRIME_MAX_LEN) {
|
||||
MS_LOG(ERROR) << "get prime size is :" << str.size();
|
||||
return false;
|
||||
} else {
|
||||
memcpy_s(prime, PRIME_MAX_LEN, str.data(), PRIME_MAX_LEN);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
|
||||
|
@ -104,6 +107,7 @@ bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::s
|
|||
return retcode;
|
||||
}
|
||||
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
|
||||
MS_LOG(INFO) << "register prime: " << prime;
|
||||
ps::Prime prime_id_pb;
|
||||
prime_id_pb.set_prime(prime);
|
||||
ps::PBMetadata prime_pb;
|
||||
|
|
|
@ -65,7 +65,7 @@ class CipherMetaStorage {
|
|||
// Register Prime.
|
||||
void RegisterPrime(const char *list_name, const std::string &prime);
|
||||
// Get tprime from shared server.
|
||||
bool GetPrimeFromServer(const char *list_name, unsigned char *prime);
|
||||
bool GetPrimeFromServer(const char *prime_name, unsigned char *prime);
|
||||
// Get client shares from shared server.
|
||||
void GetClientSharesFromServer(const char *list_name,
|
||||
std::map<std::string, std::vector<clientshare_str>> *clients_shares_list);
|
||||
|
|
|
@ -77,7 +77,7 @@ bool CipherReconStruct::CombineMask(
|
|||
std::vector<float> noise(cipher_init_->featuremap_, 0.0);
|
||||
if (GetSuvNoise(clients_share_list, record_public_keys, fl_id, &noise, secret, length) == false)
|
||||
retcode = false;
|
||||
client_keys->at(fl_id) = noise;
|
||||
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
||||
MS_LOG(INFO) << " fl_id : " << fl_id;
|
||||
MS_LOG(INFO) << "end get complete s_uv.";
|
||||
} else {
|
||||
|
@ -87,7 +87,7 @@ bool CipherReconStruct::CombineMask(
|
|||
for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) {
|
||||
noise[index_noise] *= -1;
|
||||
}
|
||||
client_keys->at(fl_id) = noise;
|
||||
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
||||
MS_LOG(INFO) << " fl_id : " << fl_id;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,7 +61,7 @@ struct RoundConfig {
|
|||
|
||||
struct CipherConfig {
|
||||
float share_secrets_ratio = 1.0;
|
||||
float get_model_ratio = 1.0;
|
||||
uint64_t cipher_time_window = 300000;
|
||||
size_t reconstruct_secrets_threshhold = 0;
|
||||
};
|
||||
|
||||
|
|
|
@ -243,11 +243,7 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
|
|||
auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
|
||||
*update_model_threshold = meta.update_model_threshold();
|
||||
} else if (meta.has_prime()) {
|
||||
auto prime_list = metadata_[name].mutable_prime_list();
|
||||
auto &prime_id = meta.prime().prime();
|
||||
if (prime_list->prime_size() == 0) {
|
||||
prime_list->add_prime(prime_id);
|
||||
}
|
||||
metadata_[name] = meta;
|
||||
} else if (meta.has_pair_client_keys()) {
|
||||
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
|
||||
auto &fl_id = meta.pair_client_keys().fl_id();
|
||||
|
|
|
@ -109,7 +109,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
|
|||
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
bool response = false;
|
||||
// bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
|
||||
|
@ -139,7 +139,8 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
"GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
response = DealClient(iter_num, get_clients_req, fbb);
|
||||
// response = DealClient(iter_num, get_clients_req, fbb);
|
||||
DealClient(iter_num, get_clients_req, fbb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -148,7 +149,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
|
||||
return response;
|
||||
return true;
|
||||
} // namespace ps
|
||||
|
||||
bool ClientListKernel::Reset() {
|
||||
|
|
|
@ -85,7 +85,10 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
|
||||
return response;
|
||||
if (!response) {
|
||||
MS_LOG(INFO) << "ExchangeKeysKernel response is false.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ExchangeKeysKernel::Reset() {
|
||||
|
|
|
@ -82,7 +82,10 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
|||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
|
||||
return response;
|
||||
if (!response) {
|
||||
MS_LOG(INFO) << "GetKeysKernel response is false.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GetKeysKernel::Reset() {
|
||||
|
|
|
@ -48,7 +48,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is "
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
|
||||
<< total_duration;
|
||||
|
||||
clock_t start_time = clock();
|
||||
|
@ -84,7 +84,10 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
|||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
|
||||
return response;
|
||||
if (!response) {
|
||||
MS_LOG(INFO) << "GetSecretsKernel response is false.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GetSecretsKernel::Reset() {
|
||||
|
|
|
@ -76,7 +76,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void *req_data = inputs[0]->addr;
|
||||
|
@ -94,7 +94,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
"update_model_client_num is zero.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
const PBMetadata client_list_pb_out =
|
||||
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||
|
@ -110,7 +110,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
"ReconstructSecretsKernel : client list is not ready", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return response;
|
||||
return true;
|
||||
}
|
||||
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
||||
client_list.push_back(client_list_pb.fl_id(i));
|
||||
|
@ -128,7 +128,10 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "reconstruct_secrets_kernel success time is : " << duration;
|
||||
return response;
|
||||
if (!response) {
|
||||
MS_LOG(INFO) << "reconstruct_secrets_kernel response is false.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
|
||||
|
|
|
@ -85,7 +85,10 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
|
||||
return response;
|
||||
if (!response) {
|
||||
MS_LOG(INFO) << "share_secrets_kernel response is false.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShareSecretsKernel::Reset() {
|
||||
|
|
|
@ -204,7 +204,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
|||
float dp_eps = param->dp_eps;
|
||||
float dp_delta = param->dp_delta;
|
||||
float dp_norm_clip = param->dp_norm_clip;
|
||||
auto encrypt_type = fbb->CreateString(param->encrypt_type);
|
||||
auto encrypt_type = fbb->CreateString(PSContext::instance()->encrypt_type());
|
||||
|
||||
auto cipher_public_params =
|
||||
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);
|
||||
|
|
|
@ -84,7 +84,11 @@ void Server::Run() {
|
|||
RegisterCommCallbacks();
|
||||
StartCommunicator();
|
||||
InitExecutor();
|
||||
InitCipher();
|
||||
std::string encrypt_type = PSContext::instance()->encrypt_type();
|
||||
if (encrypt_type != kNotEncryptType) {
|
||||
InitCipher();
|
||||
MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
|
||||
}
|
||||
RegisterRoundKernel();
|
||||
MS_LOG(INFO) << "Server started successfully.";
|
||||
safemode_ = false;
|
||||
|
@ -182,40 +186,46 @@ void Server::InitIteration() {
|
|||
iteration_->AddRound(round);
|
||||
}
|
||||
|
||||
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
|
||||
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
|
||||
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
|
||||
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
|
||||
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
|
||||
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold;
|
||||
|
||||
MS_LOG(INFO) << "Initializing cipher:";
|
||||
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
|
||||
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
|
||||
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
|
||||
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
|
||||
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
|
||||
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
|
||||
|
||||
#ifdef ENABLE_ARMOUR
|
||||
std::shared_ptr<Round> exchange_keys_round =
|
||||
std::make_shared<Round>("exchangeKeys", false, 3000, true, cipher_exchange_secrets_cnt_);
|
||||
iteration_->AddRound(exchange_keys_round);
|
||||
std::shared_ptr<Round> get_keys_round =
|
||||
std::make_shared<Round>("getKeys", false, 3000, true, cipher_exchange_secrets_cnt_);
|
||||
iteration_->AddRound(get_keys_round);
|
||||
std::shared_ptr<Round> share_secrets_round =
|
||||
std::make_shared<Round>("shareSecrets", false, 3000, true, cipher_share_secrets_cnt_);
|
||||
iteration_->AddRound(share_secrets_round);
|
||||
std::shared_ptr<Round> get_secrets_round =
|
||||
std::make_shared<Round>("getSecrets", false, 3000, true, cipher_share_secrets_cnt_);
|
||||
iteration_->AddRound(get_secrets_round);
|
||||
std::shared_ptr<Round> get_clientlist_round =
|
||||
std::make_shared<Round>("getClientList", false, 3000, true, cipher_get_clientlist_cnt_);
|
||||
iteration_->AddRound(get_clientlist_round);
|
||||
std::shared_ptr<Round> reconstruct_secrets_round =
|
||||
std::make_shared<Round>("reconstructSecrets", false, 3000, true, cipher_reconstruct_secrets_up_cnt_);
|
||||
iteration_->AddRound(reconstruct_secrets_round);
|
||||
std::string encrypt_type = PSContext::instance()->encrypt_type();
|
||||
if (encrypt_type == kPWEncryptType) {
|
||||
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
|
||||
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
|
||||
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
|
||||
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
|
||||
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
|
||||
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold;
|
||||
cipher_time_window_ = cipher_config_.cipher_time_window;
|
||||
|
||||
MS_LOG(INFO) << "Initializing cipher:";
|
||||
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
|
||||
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
|
||||
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
|
||||
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
|
||||
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
|
||||
<< " cipher_time_window_: " << cipher_time_window_
|
||||
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
|
||||
|
||||
std::shared_ptr<Round> exchange_keys_round =
|
||||
std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
|
||||
iteration_->AddRound(exchange_keys_round);
|
||||
std::shared_ptr<Round> get_keys_round =
|
||||
std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
|
||||
iteration_->AddRound(get_keys_round);
|
||||
std::shared_ptr<Round> share_secrets_round =
|
||||
std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
|
||||
iteration_->AddRound(share_secrets_round);
|
||||
std::shared_ptr<Round> get_secrets_round =
|
||||
std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
|
||||
iteration_->AddRound(get_secrets_round);
|
||||
std::shared_ptr<Round> get_clientlist_round =
|
||||
std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
|
||||
iteration_->AddRound(get_clientlist_round);
|
||||
std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
|
||||
"reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
|
||||
iteration_->AddRound(reconstruct_secrets_round);
|
||||
MS_LOG(INFO) << "Cipher rounds has been added.";
|
||||
}
|
||||
#endif
|
||||
|
||||
// 2.Initialize all the rounds.
|
||||
|
|
|
@ -170,6 +170,7 @@ class Server {
|
|||
size_t cipher_get_clientlist_cnt_;
|
||||
size_t cipher_reconstruct_secrets_up_cnt_;
|
||||
size_t cipher_reconstruct_secrets_down_cnt_;
|
||||
uint64_t cipher_time_window_;
|
||||
|
||||
float percent_for_update_model_;
|
||||
float percent_for_get_model_;
|
||||
|
|
|
@ -673,10 +673,10 @@ bool StartServerAction(const ResourcePtr &res) {
|
|||
{"pushWeight", false, 3000, true, server_num, true}};
|
||||
|
||||
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
|
||||
float get_model_ratio = ps::PSContext::instance()->get_model_ratio();
|
||||
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
|
||||
size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
|
||||
|
||||
ps::server::CipherConfig cipher_config = {share_secrets_ratio, get_model_ratio, reconstruct_secrets_threshhold};
|
||||
ps::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
|
||||
|
||||
size_t executor_threshold = 0;
|
||||
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
||||
|
|
|
@ -371,8 +371,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
.def("set_share_secrets_ratio", &PSContext::set_share_secrets_ratio,
|
||||
"Set threshold count ratio for share secrets round.")
|
||||
.def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.")
|
||||
.def("set_get_model_ratio", &PSContext::set_get_model_ratio, "Set threshold count ratio for get model round.")
|
||||
.def("get_model_ratio", &PSContext::get_model_ratio, "Get threshold count ratio for get model round.")
|
||||
.def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.")
|
||||
.def("set_reconstruct_secrets_threshhold", &PSContext::set_reconstruct_secrets_threshhold,
|
||||
"Set threshold count for reconstruct secrets round.")
|
||||
.def("reconstruct_secrets_threshhold", &PSContext::reconstruct_secrets_threshhold,
|
||||
|
|
|
@ -105,10 +105,6 @@ message ClientNoises {
|
|||
OneClientNoises one_client_noises = 1;
|
||||
}
|
||||
|
||||
message PrimeList {
|
||||
repeated bytes prime = 1;
|
||||
}
|
||||
|
||||
message PairClientKeys {
|
||||
string fl_id = 1;
|
||||
KeysPb client_keys = 2;
|
||||
|
@ -156,7 +152,6 @@ message PBMetadata {
|
|||
ClientNoises client_noises = 11;
|
||||
|
||||
Prime prime = 12;
|
||||
PrimeList prime_list = 13;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -339,9 +339,14 @@ void PSContext::set_share_secrets_ratio(float share_secrets_ratio) { share_secre
|
|||
|
||||
float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
|
||||
|
||||
void PSContext::set_get_model_ratio(float get_model_ratio) { get_model_ratio_ = get_model_ratio; }
|
||||
void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
|
||||
if (cipher_time_window_ < 0) {
|
||||
MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0..";
|
||||
}
|
||||
cipher_time_window_ = cipher_time_window;
|
||||
}
|
||||
|
||||
float PSContext::get_model_ratio() const { return get_model_ratio_; }
|
||||
uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }
|
||||
|
||||
void PSContext::set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold) {
|
||||
reconstruct_secrets_threshhold_ = reconstruct_secrets_threshhold;
|
||||
|
|
|
@ -131,8 +131,8 @@ class PSContext {
|
|||
void set_share_secrets_ratio(float share_secrets_ratio);
|
||||
float share_secrets_ratio() const;
|
||||
|
||||
void set_get_model_ratio(float get_model_ratio);
|
||||
float get_model_ratio() const;
|
||||
void set_cipher_time_window(uint64_t cipher_time_window);
|
||||
uint64_t cipher_time_window() const;
|
||||
|
||||
void set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold);
|
||||
uint64_t reconstruct_secrets_threshhold() const;
|
||||
|
@ -201,7 +201,7 @@ class PSContext {
|
|||
update_model_ratio_(1.0),
|
||||
update_model_time_window_(3000),
|
||||
share_secrets_ratio_(1.0),
|
||||
get_model_ratio_(1.0),
|
||||
cipher_time_window_(300000),
|
||||
reconstruct_secrets_threshhold_(2000),
|
||||
fl_iteration_num_(20),
|
||||
client_epoch_num_(25),
|
||||
|
@ -260,8 +260,8 @@ class PSContext {
|
|||
// Share model threshold is a certain ratio of share secrets threshold which is set as share_secrets_ratio_.
|
||||
float share_secrets_ratio_;
|
||||
|
||||
// Get model threshold is a certain ratio of get model threshold which is set as get_model_ratio_.
|
||||
float get_model_ratio_;
|
||||
// The time window of each cipher round in millisecond.
|
||||
uint64_t cipher_time_window_;
|
||||
|
||||
// The threshold count of reconstruct secrets round. Used in federated learning for now.
|
||||
uint64_t reconstruct_secrets_threshhold_;
|
||||
|
|
|
@ -840,7 +840,7 @@ def set_fl_context(**kwargs):
|
|||
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
|
||||
share_secrets_ratio (float): The ratio for computing the threshold count of share secrets. Default: 1.0.
|
||||
update_model_ratio (float): The ratio for computing the threshold count of updateModel. Default: 1.0.
|
||||
get_model_ratio (float): The ratio for computing the threshold count of get model. Default: 1.0.
|
||||
cipher_time_window (int): The time window duration for each cipher round in millisecond. Default: 300000.
|
||||
reconstruct_secrets_threshold (int): The threshold count of reconstruct threshold. Default: 0.
|
||||
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
|
||||
fl_name (string): The federated learning job name. Default: ''.
|
||||
|
|
|
@ -58,7 +58,7 @@ _set_ps_context_func_map = {
|
|||
"update_model_ratio": ps_context().set_update_model_ratio,
|
||||
"update_model_time_window": ps_context().set_update_model_time_window,
|
||||
"share_secrets_ratio": ps_context().set_share_secrets_ratio,
|
||||
"get_model_ratio": ps_context().set_get_model_ratio,
|
||||
"cipher_time_window": ps_context().set_cipher_time_window,
|
||||
"reconstruct_secrets_threshhold": ps_context().set_reconstruct_secrets_threshhold,
|
||||
"fl_name": ps_context().set_fl_name,
|
||||
"fl_iteration_num": ps_context().set_fl_iteration_num,
|
||||
|
@ -91,7 +91,7 @@ _get_ps_context_func_map = {
|
|||
"update_model_ratio": ps_context().update_model_ratio,
|
||||
"update_model_time_window": ps_context().update_model_time_window,
|
||||
"share_secrets_ratio": ps_context().share_secrets_ratio,
|
||||
"get_model_ratio": ps_context().get_model_ratio,
|
||||
"cipher_time_window": ps_context().set_cipher_time_window,
|
||||
"reconstruct_secrets_threshhold": ps_context().reconstruct_secrets_threshhold,
|
||||
"fl_name": ps_context().fl_name,
|
||||
"fl_iteration_num": ps_context().fl_iteration_num,
|
||||
|
|
|
@ -35,6 +35,15 @@ parser.add_argument("--client_batch_size", type=int, default=32)
|
|||
parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
||||
parser.add_argument("--local_server_num", type=int, default=-1)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
||||
# parameters for encrypt_type='DP_ENCRYPT'
|
||||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
# parameters for encrypt_type='PW_ENCRYPT'
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
||||
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -55,6 +64,13 @@ client_batch_size = args.client_batch_size
|
|||
client_learning_rate = args.client_learning_rate
|
||||
local_server_num = args.local_server_num
|
||||
config_file_path = args.config_file_path
|
||||
encrypt_type = args.encrypt_type
|
||||
share_secrets_ratio = args.share_secrets_ratio
|
||||
cipher_time_window = args.cipher_time_window
|
||||
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
|
||||
dp_eps = args.dp_eps
|
||||
dp_delta = args.dp_delta
|
||||
dp_norm_clip = args.dp_norm_clip
|
||||
|
||||
if local_server_num == -1:
|
||||
local_server_num = server_num
|
||||
|
@ -85,6 +101,13 @@ for i in range(local_server_num):
|
|||
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
|
||||
cmd_server += " --client_batch_size=" + str(client_batch_size)
|
||||
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
|
||||
cmd_server += " --encrypt_type=" + str(encrypt_type)
|
||||
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
|
||||
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
|
||||
cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold)
|
||||
cmd_server += " --dp_eps=" + str(dp_eps)
|
||||
cmd_server += " --dp_delta=" + str(dp_delta)
|
||||
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
|
||||
cmd_server += " > server.log 2>&1 &"
|
||||
|
||||
import time
|
||||
|
|
|
@ -45,6 +45,15 @@ parser.add_argument("--client_learning_rate", type=float, default=0.1)
|
|||
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
|
||||
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
||||
# parameters for encrypt_type='DP_ENCRYPT'
|
||||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
# parameters for encrypt_type='PW_ENCRYPT'
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
||||
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -67,6 +76,13 @@ client_learning_rate = args.client_learning_rate
|
|||
worker_step_num_per_iteration = args.worker_step_num_per_iteration
|
||||
scheduler_manage_port = args.scheduler_manage_port
|
||||
config_file_path = args.config_file_path
|
||||
encrypt_type = args.encrypt_type
|
||||
share_secrets_ratio = args.share_secrets_ratio
|
||||
cipher_time_window = args.cipher_time_window
|
||||
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
|
||||
dp_eps = args.dp_eps
|
||||
dp_delta = args.dp_delta
|
||||
dp_norm_clip = args.dp_norm_clip
|
||||
|
||||
ctx = {
|
||||
"enable_fl": True,
|
||||
|
@ -88,7 +104,14 @@ ctx = {
|
|||
"client_learning_rate": client_learning_rate,
|
||||
"worker_step_num_per_iteration": worker_step_num_per_iteration,
|
||||
"scheduler_manage_port": scheduler_manage_port,
|
||||
"config_file_path": config_file_path
|
||||
"config_file_path": config_file_path,
|
||||
"share_secrets_ratio": share_secrets_ratio,
|
||||
"cipher_time_window": cipher_time_window,
|
||||
"reconstruct_secrets_threshhold": reconstruct_secrets_threshhold,
|
||||
"dp_eps": dp_eps,
|
||||
"dp_delta": dp_delta,
|
||||
"dp_norm_clip": dp_norm_clip,
|
||||
"encrypt_type": encrypt_type
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)
|
||||
|
|
|
@ -28,9 +28,6 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
|||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--update_model_time_window", type=int, default=3000)
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--get_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0)
|
||||
parser.add_argument("--fl_name", type=str, default="Lenet")
|
||||
parser.add_argument("--fl_iteration_num", type=int, default=25)
|
||||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
|
@ -43,6 +40,10 @@ parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
|||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
# parameters for encrypt_type='PW_ENCRYPT'
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
||||
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
|
@ -58,7 +59,7 @@ if __name__ == "__main__":
|
|||
update_model_ratio = args.update_model_ratio
|
||||
update_model_time_window = args.update_model_time_window
|
||||
share_secrets_ratio = args.share_secrets_ratio
|
||||
get_model_ratio = args.get_model_ratio
|
||||
cipher_time_window = args.cipher_time_window
|
||||
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
|
||||
fl_name = args.fl_name
|
||||
fl_iteration_num = args.fl_iteration_num
|
||||
|
@ -96,7 +97,7 @@ if __name__ == "__main__":
|
|||
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
|
||||
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
|
||||
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
|
||||
cmd_server += " --get_model_ratio=" + str(get_model_ratio)
|
||||
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
|
||||
cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold)
|
||||
cmd_server += " --fl_name=" + fl_name
|
||||
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
|
||||
|
|
|
@ -36,9 +36,6 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
|||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--update_model_time_window", type=int, default=3000)
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--get_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0)
|
||||
parser.add_argument("--fl_name", type=str, default="Lenet")
|
||||
parser.add_argument("--fl_iteration_num", type=int, default=25)
|
||||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
|
@ -51,6 +48,10 @@ parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
|||
parser.add_argument("--dp_eps", type=float, default=50.0)
|
||||
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
|
||||
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
|
||||
# parameters for encrypt_type='PW_ENCRYPT'
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
||||
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -66,7 +67,7 @@ start_fl_job_time_window = args.start_fl_job_time_window
|
|||
update_model_ratio = args.update_model_ratio
|
||||
update_model_time_window = args.update_model_time_window
|
||||
share_secrets_ratio = args.share_secrets_ratio
|
||||
get_model_ratio = args.get_model_ratio
|
||||
cipher_time_window = args.cipher_time_window
|
||||
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
|
||||
fl_name = args.fl_name
|
||||
fl_iteration_num = args.fl_iteration_num
|
||||
|
@ -94,7 +95,7 @@ ctx = {
|
|||
"update_model_ratio": update_model_ratio,
|
||||
"update_model_time_window": update_model_time_window,
|
||||
"share_secrets_ratio": share_secrets_ratio,
|
||||
"get_model_ratio": get_model_ratio,
|
||||
"cipher_time_window": cipher_time_window,
|
||||
"reconstruct_secrets_threshhold": reconstruct_secrets_threshhold,
|
||||
"fl_name": fl_name,
|
||||
"fl_iteration_num": fl_iteration_num,
|
||||
|
|
Loading…
Reference in New Issue