Add secure parameters for mindspore federated learning.

fix prime initialization and build key bug

fix reconstruct access bug and kernels' retcode.

Fix init issue

Add fl secure parameter cipher_time_window
This commit is contained in:
jin-xiulang 2021-07-03 13:26:19 +08:00
parent e372634d16
commit d2b42fd12a
27 changed files with 203 additions and 113 deletions

View File

@ -15,8 +15,10 @@
*/
#include "fl/armour/cipher/cipher_init.h"
#include "fl/server/common.h"
#include "fl/armour/cipher/cipher_meta_storage.h"
#include "fl/server/common.h"
#include "fl/server/model_store.h"
namespace mindspore {
namespace armour {
@ -43,8 +45,7 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
publicparam_.t = param.t;
secrets_minnums_ = param.t;
client_num_need_ = cipher_initial_client_cnt;
featuremap_ = 1000; // todo: wait for other code
// merge.ps::server::DistributedMetadataStore::GetInstance().model_size() / sizeof(float);
featuremap_ = ps::server::ModelStore::GetInstance().model_size() / sizeof(float);
share_clients_num_need_ = cipher_share_secrets_cnt;
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
get_model_num_need_ = cipher_get_clientlist_cnt;

View File

@ -38,14 +38,14 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
std::string fl_id = get_exchange_keys_req->fl_id()->str();
if (find(clients.begin(), clients.end(), fl_id) == clients.end()) {
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
MS_LOG(INFO) << "The fl_id is not in clients.";
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
return false;
}
if (cur_clients_num < cipher_init_->client_num_need_) {
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
MS_LOG(INFO) << "The server is not ready yet: cur_clients_num < client_num_need";
MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << "cur_clients_num : " << cipher_init_->client_num_need_;
MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << ", cur_clients_num : " << cipher_init_->client_num_need_;
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
return false;
}
@ -55,7 +55,7 @@ bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_tim
}
bool flag =
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, cur_iterator, next_req_time, true);
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
return flag;
} // namespace armour
@ -95,17 +95,17 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
MS_LOG(INFO) << "client_num_need_ " << cipher_init_->client_num_need_ << ". cur_clients_num " << cur_clients_num;
std::string fl_id = exchange_keys_req->fl_id()->str();
if (cur_clients_num >= cipher_init_->client_num_need_) { // the client num is enough, return false.
MS_LOG(ERROR) << "The server has received enough requests and refuse this request.";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime,
"The server has received enough requests and refuse this request.", next_req_time,
cur_iterator);
MS_LOG(ERROR) << "The server has received enough requests and refuse this request.";
return false;
}
if (record_public_keys.find(fl_id) != record_public_keys.end()) { // the client already exists, return false.
MS_LOG(INFO) << "The server has received the request, please do not request again.";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
"The server has received the request, please do not request again.", next_req_time,
cur_iterator);
MS_LOG(INFO) << "The server has received the request, please do not request again.";
return false;
}
@ -135,9 +135,9 @@ bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_re
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(ps::server::kCtxExChangeKeysClientList, fl_id);
if (retcode_key && retcode_client) {
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
"Success, but the server is not ready yet.", next_req_time, cur_iterator);
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
return true;
} else {
MS_LOG(ERROR) << "update key or client failed";
@ -164,7 +164,6 @@ void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<ps::server::FBBuilder> exc
bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good) {
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
bool flag = true;
if (is_good) {
// convert client keys to standard keys list.
@ -176,6 +175,14 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
<< "clients num: " << cipher_init_->client_num_need_;
flag = false;
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
} else {
for (auto iter = record_public_keys.begin(); iter != record_public_keys.end(); ++iter) {
// read (fl_id, c_pk, s_pk) from the map: record_public_keys_
@ -190,18 +197,26 @@ bool CipherKeys::BuildGetKeys(std::shared_ptr<ps::server::FBBuilder> fbb, const
public_keys_list.push_back(cur_public_key);
}
auto remote_publickeys = fbb->CreateVector(public_keys_list);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_remote_publickeys(remote_publickeys);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
flag = true;
}
}
} else {
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
auto fbs_next_req_time = fbb->CreateString(next_req_time);
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
fbb->Finish(rsp_get_keys);
}
return flag;
}

View File

@ -82,16 +82,19 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
return true;
}
bool CipherMetaStorage::GetPrimeFromServer(const char *list_name, unsigned char *prime) {
const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
auto &prime_list_pb = prime_pb_out.prime_list();
if (prime_list_pb.prime_size() > 0 && prime_list_pb.prime(0).size() >= PRIME_MAX_LEN) {
for (int i = 0; i < PRIME_MAX_LEN; i++) {
prime[i] = static_cast<unsigned char>(prime_list_pb.prime(0)[i]);
}
bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char *prime) {
const ps::PBMetadata &prime_pb_out = ps::server::DistributedMetadataStore::GetInstance().GetMetadata(prime_name);
ps::Prime prime_pb(prime_pb_out.prime());
std::string str = *(prime_pb.mutable_prime());
MS_LOG(INFO) << "get prime from metastorage :" << str;
if (str.size() != PRIME_MAX_LEN) {
MS_LOG(ERROR) << "get prime size is :" << str.size();
return false;
} else {
memcpy_s(prime, PRIME_MAX_LEN, str.data(), PRIME_MAX_LEN);
return true;
}
return false;
}
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
@ -104,6 +107,7 @@ bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::s
return retcode;
}
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
MS_LOG(INFO) << "register prime: " << prime;
ps::Prime prime_id_pb;
prime_id_pb.set_prime(prime);
ps::PBMetadata prime_pb;

View File

@ -65,7 +65,7 @@ class CipherMetaStorage {
// Register Prime.
void RegisterPrime(const char *list_name, const std::string &prime);
// Get tprime from shared server.
bool GetPrimeFromServer(const char *list_name, unsigned char *prime);
bool GetPrimeFromServer(const char *prime_name, unsigned char *prime);
// Get client shares from shared server.
void GetClientSharesFromServer(const char *list_name,
std::map<std::string, std::vector<clientshare_str>> *clients_shares_list);

View File

@ -77,7 +77,7 @@ bool CipherReconStruct::CombineMask(
std::vector<float> noise(cipher_init_->featuremap_, 0.0);
if (GetSuvNoise(clients_share_list, record_public_keys, fl_id, &noise, secret, length) == false)
retcode = false;
client_keys->at(fl_id) = noise;
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
MS_LOG(INFO) << " fl_id : " << fl_id;
MS_LOG(INFO) << "end get complete s_uv.";
} else {
@ -87,7 +87,7 @@ bool CipherReconStruct::CombineMask(
for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) {
noise[index_noise] *= -1;
}
client_keys->at(fl_id) = noise;
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
MS_LOG(INFO) << " fl_id : " << fl_id;
}
}

View File

@ -61,7 +61,7 @@ struct RoundConfig {
struct CipherConfig {
float share_secrets_ratio = 1.0;
float get_model_ratio = 1.0;
uint64_t cipher_time_window = 300000;
size_t reconstruct_secrets_threshhold = 0;
};

View File

@ -243,11 +243,7 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
*update_model_threshold = meta.update_model_threshold();
} else if (meta.has_prime()) {
auto prime_list = metadata_[name].mutable_prime_list();
auto &prime_id = meta.prime().prime();
if (prime_list->prime_size() == 0) {
prime_list->add_prime(prime_id);
}
metadata_[name] = meta;
} else if (meta.has_pair_client_keys()) {
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
auto &fl_id = meta.pair_client_keys().fl_id();

View File

@ -109,7 +109,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
bool response = false;
// bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
@ -139,7 +139,8 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
"GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
response = DealClient(iter_num, get_clients_req, fbb);
// response = DealClient(iter_num, get_clients_req, fbb);
DealClient(iter_num, get_clients_req, fbb);
}
}
}
@ -148,7 +149,7 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
return response;
return true;
} // namespace ps
bool ClientListKernel::Reset() {

View File

@ -85,7 +85,10 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
return response;
if (!response) {
MS_LOG(INFO) << "ExchangeKeysKernel response is false.";
}
return true;
}
bool ExchangeKeysKernel::Reset() {

View File

@ -82,7 +82,10 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
return response;
if (!response) {
MS_LOG(INFO) << "GetKeysKernel response is false.";
}
return true;
}
bool GetKeysKernel::Reset() {

View File

@ -48,7 +48,7 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is "
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
<< total_duration;
clock_t start_time = clock();
@ -84,7 +84,10 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
return response;
if (!response) {
MS_LOG(INFO) << "GetSecretsKernel response is false.";
}
return true;
}
bool GetSecretsKernel::Reset() {

View File

@ -76,7 +76,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
return true;
}
void *req_data = inputs[0]->addr;
@ -94,7 +94,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
"update_model_client_num is zero.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
return true;
}
const PBMetadata client_list_pb_out =
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
@ -110,7 +110,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
"ReconstructSecretsKernel : client list is not ready", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return response;
return true;
}
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
@ -128,7 +128,10 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "reconstruct_secrets_kernel success time is : " << duration;
return response;
if (!response) {
MS_LOG(INFO) << "reconstruct_secrets_kernel response is false.";
}
return true;
}
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {

View File

@ -85,7 +85,10 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
return response;
if (!response) {
MS_LOG(INFO) << "share_secrets_kernel response is false.";
}
return true;
}
bool ShareSecretsKernel::Reset() {

View File

@ -204,7 +204,7 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
float dp_eps = param->dp_eps;
float dp_delta = param->dp_delta;
float dp_norm_clip = param->dp_norm_clip;
auto encrypt_type = fbb->CreateString(param->encrypt_type);
auto encrypt_type = fbb->CreateString(PSContext::instance()->encrypt_type());
auto cipher_public_params =
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);

View File

@ -84,7 +84,11 @@ void Server::Run() {
RegisterCommCallbacks();
StartCommunicator();
InitExecutor();
InitCipher();
std::string encrypt_type = PSContext::instance()->encrypt_type();
if (encrypt_type != kNotEncryptType) {
InitCipher();
MS_LOG(INFO) << "Parameters for secure aggregation have been initiated.";
}
RegisterRoundKernel();
MS_LOG(INFO) << "Server started successfully.";
safemode_ = false;
@ -182,40 +186,46 @@ void Server::InitIteration() {
iteration_->AddRound(round);
}
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold;
MS_LOG(INFO) << "Initializing cipher:";
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
#ifdef ENABLE_ARMOUR
std::shared_ptr<Round> exchange_keys_round =
std::make_shared<Round>("exchangeKeys", false, 3000, true, cipher_exchange_secrets_cnt_);
iteration_->AddRound(exchange_keys_round);
std::shared_ptr<Round> get_keys_round =
std::make_shared<Round>("getKeys", false, 3000, true, cipher_exchange_secrets_cnt_);
iteration_->AddRound(get_keys_round);
std::shared_ptr<Round> share_secrets_round =
std::make_shared<Round>("shareSecrets", false, 3000, true, cipher_share_secrets_cnt_);
iteration_->AddRound(share_secrets_round);
std::shared_ptr<Round> get_secrets_round =
std::make_shared<Round>("getSecrets", false, 3000, true, cipher_share_secrets_cnt_);
iteration_->AddRound(get_secrets_round);
std::shared_ptr<Round> get_clientlist_round =
std::make_shared<Round>("getClientList", false, 3000, true, cipher_get_clientlist_cnt_);
iteration_->AddRound(get_clientlist_round);
std::shared_ptr<Round> reconstruct_secrets_round =
std::make_shared<Round>("reconstructSecrets", false, 3000, true, cipher_reconstruct_secrets_up_cnt_);
iteration_->AddRound(reconstruct_secrets_round);
std::string encrypt_type = PSContext::instance()->encrypt_type();
if (encrypt_type == kPWEncryptType) {
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold;
cipher_time_window_ = cipher_config_.cipher_time_window;
MS_LOG(INFO) << "Initializing cipher:";
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
<< " cipher_time_window_: " << cipher_time_window_
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
std::shared_ptr<Round> exchange_keys_round =
std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
iteration_->AddRound(exchange_keys_round);
std::shared_ptr<Round> get_keys_round =
std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
iteration_->AddRound(get_keys_round);
std::shared_ptr<Round> share_secrets_round =
std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
iteration_->AddRound(share_secrets_round);
std::shared_ptr<Round> get_secrets_round =
std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
iteration_->AddRound(get_secrets_round);
std::shared_ptr<Round> get_clientlist_round =
std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
iteration_->AddRound(get_clientlist_round);
std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
"reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
iteration_->AddRound(reconstruct_secrets_round);
MS_LOG(INFO) << "Cipher rounds has been added.";
}
#endif
// 2.Initialize all the rounds.

View File

@ -170,6 +170,7 @@ class Server {
size_t cipher_get_clientlist_cnt_;
size_t cipher_reconstruct_secrets_up_cnt_;
size_t cipher_reconstruct_secrets_down_cnt_;
uint64_t cipher_time_window_;
float percent_for_update_model_;
float percent_for_get_model_;

View File

@ -673,10 +673,10 @@ bool StartServerAction(const ResourcePtr &res) {
{"pushWeight", false, 3000, true, server_num, true}};
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
float get_model_ratio = ps::PSContext::instance()->get_model_ratio();
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
ps::server::CipherConfig cipher_config = {share_secrets_ratio, get_model_ratio, reconstruct_secrets_threshhold};
ps::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshhold};
size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {

View File

@ -371,8 +371,7 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_share_secrets_ratio", &PSContext::set_share_secrets_ratio,
"Set threshold count ratio for share secrets round.")
.def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.")
.def("set_get_model_ratio", &PSContext::set_get_model_ratio, "Set threshold count ratio for get model round.")
.def("get_model_ratio", &PSContext::get_model_ratio, "Get threshold count ratio for get model round.")
.def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.")
.def("set_reconstruct_secrets_threshhold", &PSContext::set_reconstruct_secrets_threshhold,
"Set threshold count for reconstruct secrets round.")
.def("reconstruct_secrets_threshhold", &PSContext::reconstruct_secrets_threshhold,

View File

@ -105,10 +105,6 @@ message ClientNoises {
OneClientNoises one_client_noises = 1;
}
message PrimeList {
repeated bytes prime = 1;
}
message PairClientKeys {
string fl_id = 1;
KeysPb client_keys = 2;
@ -156,7 +152,6 @@ message PBMetadata {
ClientNoises client_noises = 11;
Prime prime = 12;
PrimeList prime_list = 13;
}
}

View File

@ -339,9 +339,14 @@ void PSContext::set_share_secrets_ratio(float share_secrets_ratio) { share_secre
float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
void PSContext::set_get_model_ratio(float get_model_ratio) { get_model_ratio_ = get_model_ratio; }
void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
if (cipher_time_window_ < 0) {
MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0..";
}
cipher_time_window_ = cipher_time_window;
}
float PSContext::get_model_ratio() const { return get_model_ratio_; }
uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }
void PSContext::set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold) {
reconstruct_secrets_threshhold_ = reconstruct_secrets_threshhold;

View File

@ -131,8 +131,8 @@ class PSContext {
void set_share_secrets_ratio(float share_secrets_ratio);
float share_secrets_ratio() const;
void set_get_model_ratio(float get_model_ratio);
float get_model_ratio() const;
void set_cipher_time_window(uint64_t cipher_time_window);
uint64_t cipher_time_window() const;
void set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold);
uint64_t reconstruct_secrets_threshhold() const;
@ -201,7 +201,7 @@ class PSContext {
update_model_ratio_(1.0),
update_model_time_window_(3000),
share_secrets_ratio_(1.0),
get_model_ratio_(1.0),
cipher_time_window_(300000),
reconstruct_secrets_threshhold_(2000),
fl_iteration_num_(20),
client_epoch_num_(25),
@ -260,8 +260,8 @@ class PSContext {
// Share model threshold is a certain ratio of share secrets threshold which is set as share_secrets_ratio_.
float share_secrets_ratio_;
// Get model threshold is a certain ratio of get model threshold which is set as get_model_ratio_.
float get_model_ratio_;
// The time window of each cipher round in millisecond.
uint64_t cipher_time_window_;
// The threshold count of reconstruct secrets round. Used in federated learning for now.
uint64_t reconstruct_secrets_threshhold_;

View File

@ -840,7 +840,7 @@ def set_fl_context(**kwargs):
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
share_secrets_ratio (float): The ratio for computing the threshold count of share secrets. Default: 1.0.
update_model_ratio (float): The ratio for computing the threshold count of updateModel. Default: 1.0.
get_model_ratio (float): The ratio for computing the threshold count of get model. Default: 1.0.
cipher_time_window (int): The time window duration for each cipher round in millisecond. Default: 300000.
reconstruct_secrets_threshold (int): The threshold count of reconstruct threshold. Default: 0.
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
fl_name (string): The federated learning job name. Default: ''.

View File

@ -58,7 +58,7 @@ _set_ps_context_func_map = {
"update_model_ratio": ps_context().set_update_model_ratio,
"update_model_time_window": ps_context().set_update_model_time_window,
"share_secrets_ratio": ps_context().set_share_secrets_ratio,
"get_model_ratio": ps_context().set_get_model_ratio,
"cipher_time_window": ps_context().set_cipher_time_window,
"reconstruct_secrets_threshhold": ps_context().set_reconstruct_secrets_threshhold,
"fl_name": ps_context().set_fl_name,
"fl_iteration_num": ps_context().set_fl_iteration_num,
@ -91,7 +91,7 @@ _get_ps_context_func_map = {
"update_model_ratio": ps_context().update_model_ratio,
"update_model_time_window": ps_context().update_model_time_window,
"share_secrets_ratio": ps_context().share_secrets_ratio,
"get_model_ratio": ps_context().get_model_ratio,
"cipher_time_window": ps_context().set_cipher_time_window,
"reconstruct_secrets_threshhold": ps_context().reconstruct_secrets_threshhold,
"fl_name": ps_context().fl_name,
"fl_iteration_num": ps_context().fl_iteration_num,

View File

@ -35,6 +35,15 @@ parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -55,6 +64,13 @@ client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
if local_server_num == -1:
local_server_num = server_num
@ -85,6 +101,13 @@ for i in range(local_server_num):
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold)
cmd_server += " --dp_eps=" + str(dp_eps)
cmd_server += " --dp_delta=" + str(dp_delta)
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
cmd_server += " > server.log 2>&1 &"
import time

View File

@ -45,6 +45,15 @@ parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -67,6 +76,13 @@ client_learning_rate = args.client_learning_rate
worker_step_num_per_iteration = args.worker_step_num_per_iteration
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
share_secrets_ratio = args.share_secrets_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
ctx = {
"enable_fl": True,
@ -88,7 +104,14 @@ ctx = {
"client_learning_rate": client_learning_rate,
"worker_step_num_per_iteration": worker_step_num_per_iteration,
"scheduler_manage_port": scheduler_manage_port,
"config_file_path": config_file_path
"config_file_path": config_file_path,
"share_secrets_ratio": share_secrets_ratio,
"cipher_time_window": cipher_time_window,
"reconstruct_secrets_threshhold": reconstruct_secrets_threshhold,
"dp_eps": dp_eps,
"dp_delta": dp_delta,
"dp_norm_clip": dp_norm_clip,
"encrypt_type": encrypt_type
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)

View File

@ -28,9 +28,6 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--get_model_ratio", type=float, default=1.0)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
@ -43,6 +40,10 @@ parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
if __name__ == "__main__":
args, _ = parser.parse_known_args()
@ -58,7 +59,7 @@ if __name__ == "__main__":
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
share_secrets_ratio = args.share_secrets_ratio
get_model_ratio = args.get_model_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
@ -96,7 +97,7 @@ if __name__ == "__main__":
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
cmd_server += " --get_model_ratio=" + str(get_model_ratio)
cmd_server += " --cipher_time_window=" + str(cipher_time_window)
cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold)
cmd_server += " --fl_name=" + fl_name
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)

View File

@ -36,9 +36,6 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--get_model_ratio", type=float, default=1.0)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
@ -51,6 +48,10 @@ parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)
# parameters for encrypt_type='PW_ENCRYPT'
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=3)
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -66,7 +67,7 @@ start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
share_secrets_ratio = args.share_secrets_ratio
get_model_ratio = args.get_model_ratio
cipher_time_window = args.cipher_time_window
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
@ -94,7 +95,7 @@ ctx = {
"update_model_ratio": update_model_ratio,
"update_model_time_window": update_model_time_window,
"share_secrets_ratio": share_secrets_ratio,
"get_model_ratio": get_model_ratio,
"cipher_time_window": cipher_time_window,
"reconstruct_secrets_threshhold": reconstruct_secrets_threshhold,
"fl_name": fl_name,
"fl_iteration_num": fl_iteration_num,