Fix security problems and code-check problems for federated's secure aggregation

fix security check problems for flclient
This commit is contained in:
jin-xiulang 2021-09-10 10:30:49 +08:00
parent 0abff9ad65
commit 7d9dd343f3
68 changed files with 4321 additions and 2225 deletions

View File

@ -39,7 +39,7 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/key_agreement.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/random.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/masking.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/secret_sharing.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_init.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_keys.cc")

View File

@ -22,26 +22,28 @@
namespace mindspore {
namespace armour {
bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_initial_client_cnt,
size_t cipher_exchange_secrets_cnt, size_t cipher_share_secrets_cnt,
bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
size_t cipher_reconstruct_secrets_up_cnt) {
MS_LOG(INFO) << "CipherInit::Init START";
int return_num = 0;
return_num = memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN);
if (return_num != 0) {
if (memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN) != 0) {
MS_LOG(ERROR) << "CipherInit::memory copy failed.";
return false;
}
publicparam_.g = param.g;
publicparam_.t = param.t;
secrets_minnums_ = param.t;
client_num_need_ = cipher_initial_client_cnt;
featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float);
share_clients_num_need_ = cipher_share_secrets_cnt;
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
get_model_num_need_ = cipher_get_clientlist_cnt;
exchange_key_threshold = cipher_exchange_keys_cnt;
get_key_threshold = cipher_get_keys_cnt;
share_secrets_threshold = cipher_share_secrets_cnt;
get_secrets_threshold = cipher_get_secrets_cnt;
client_list_threshold = cipher_get_clientlist_cnt;
reconstruct_secrets_threshold = cipher_reconstruct_secrets_up_cnt;
time_out_mutex_ = time_out_mutex;
publicparam_.dp_eps = param.dp_eps;
publicparam_.dp_delta = param.dp_delta;
@ -62,10 +64,12 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
return false;
}
MS_LOG(INFO) << " CipherInit client_num_need_ : " << client_num_need_;
MS_LOG(INFO) << " CipherInit share_clients_num_need_ : " << share_clients_num_need_;
MS_LOG(INFO) << " CipherInit reconstruct_clients_num_need_ : " << reconstruct_clients_num_need_;
MS_LOG(INFO) << " CipherInit get_model_num_need_ : " << get_model_num_need_;
MS_LOG(INFO) << " CipherInit exchange_key_threshold : " << exchange_key_threshold;
MS_LOG(INFO) << " CipherInit get_key_threshold : " << get_key_threshold;
MS_LOG(INFO) << " CipherInit share_secrets_threshold : " << share_secrets_threshold;
MS_LOG(INFO) << " CipherInit get_secrets_threshold : " << get_secrets_threshold;
MS_LOG(INFO) << " CipherInit client_list_threshold : " << client_list_threshold;
MS_LOG(INFO) << " CipherInit reconstruct_secrets_threshold : " << reconstruct_secrets_threshold;
MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_;
if (!Check_Parames()) {
MS_LOG(ERROR) << "Cipher parameters are illegal.";
@ -82,11 +86,10 @@ bool CipherInit::Check_Parames() {
return false;
}
if (share_clients_num_need_ < reconstruct_clients_num_need_) {
MS_LOG(ERROR)
<< "reconstruct_clients_num_need (which is reconstruct_secrets_threshold + 1) should not be larger "
"than share_clients_num_need (which is start_fl_job_threshold*share_secrets_ratio), but got they are:"
<< reconstruct_clients_num_need_ << ", " << share_clients_num_need_;
if (share_secrets_threshold < reconstruct_secrets_threshold) {
MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger "
"than share_secrets_threshold, but got they are:"
<< reconstruct_secrets_threshold << ", " << share_secrets_threshold;
return false;
}

View File

@ -29,17 +29,6 @@
namespace mindspore {
namespace armour {
template <typename T1>
bool CreateArray(std::vector<T1> *newData, const flatbuffers::Vector<T1> &fbs_arr) {
size_t size = newData->size();
size_t size_fbs_arr = fbs_arr.size();
if (size != size_fbs_arr) return false;
for (size_t i = 0; i < size; ++i) {
newData->at(i) = fbs_arr.Get(i);
}
return true;
}
// Initialization of secure aggregation.
class CipherInit {
public:
@ -49,9 +38,10 @@ class CipherInit {
}
// Initialize the parameters of the secure aggregation.
bool Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_initial_client_cnt,
size_t cipher_exchange_secrets_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_clientlist_cnt,
size_t cipher_reconstruct_secrets_down_cnt, size_t cipher_reconstruct_secrets_up_cnt);
bool Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
size_t cipher_reconstruct_secrets_up_cnt);
// Check whether the parameters are valid.
bool Check_Parames();
@ -59,10 +49,12 @@ class CipherInit {
// Get public params. which is given to start fl job thread.
CipherPublicPara *GetPublicParams() { return &publicparam_; }
size_t share_clients_num_need_; // the minimum number of clients to share secrets.
size_t reconstruct_clients_num_need_; // the minimum number of clients to reconstruct secret mask.
size_t client_num_need_; // the minimum number of clients to update model.
size_t get_model_num_need_; // the minimum number of clients to get model.
size_t share_secrets_threshold; // the minimum number of clients to share secret fragments.
size_t get_secrets_threshold; // the minimum number of clients to get secret fragments.
size_t reconstruct_secrets_threshold; // the minimum number of clients to reconstruct secret mask.
size_t exchange_key_threshold; // the minimum number of clients to send public keys.
size_t get_key_threshold; // the minimum number of clients to get public keys.
size_t client_list_threshold; // the minimum number of clients to get update model client list.
size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask.
size_t featuremap_; // the size of data to deal.

View File

@ -21,203 +21,170 @@ namespace mindspore {
namespace armour {
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &get_exchange_keys_resp_builder) {
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
MS_LOG(INFO) << "CipherMgr::GetKeys START";
if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false);
if (get_exchange_keys_req == nullptr) {
MS_LOG(ERROR) << "Request is nullptr";
BuildGetKeysRsp(fbb, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false);
return false;
}
// get clientlist from memory server.
std::vector<std::string> clients;
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &clients);
size_t cur_clients_num = clients.size();
size_t cur_exchange_clients_num = client_public_keys.size();
std::string fl_id = get_exchange_keys_req->fl_id()->str();
if (find(clients.begin(), clients.end(), fl_id) == clients.end()) {
MS_LOG(INFO) << "The fl_id is not in clients.";
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
if (cur_exchange_clients_num < cipher_init_->exchange_key_threshold) {
MS_LOG(INFO) << "The server is not ready yet: cur_exchangekey_clients_num < exchange_key_threshold";
MS_LOG(INFO) << "cur_exchangekey_clients_num : " << cur_exchange_clients_num
<< ", exchange_key_threshold : " << cipher_init_->exchange_key_threshold;
BuildGetKeysRsp(fbb, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
return false;
}
if (cur_clients_num < cipher_init_->client_num_need_) {
MS_LOG(INFO) << "The server is not ready yet: cur_clients_num < client_num_need";
MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << ", client_num_need : " << cipher_init_->client_num_need_;
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
if (client_public_keys.find(fl_id) == client_public_keys.end()) {
MS_LOG(INFO) << "Get keys: the fl_id: " << fl_id << "is not in exchange keys clients.";
BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
return false;
}
bool ret = cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetKeysClientList, fl_id);
if (!ret) {
MS_LOG(ERROR) << "update get keys clients failed";
BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, cur_iterator, next_req_time, false);
return false;
}
MS_LOG(INFO) << "GetKeys client list: ";
for (size_t i = 0; i < clients.size(); i++) {
MS_LOG(INFO) << "fl_id: " << clients[i];
}
bool flag =
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
return flag;
} // namespace armour
BuildGetKeysRsp(fbb, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
return true;
}
bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder) {
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
// step 0: judge if the input param is legal.
if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
std::string reason = "Request is nullptr or Response builder is nullptr.";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time,
cur_iterator);
if (exchange_keys_req == nullptr) {
std::string reason = "Request is nullptr";
MS_LOG(ERROR) << reason;
BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason, next_req_time, cur_iterator);
return false;
}
std::string fl_id = exchange_keys_req->fl_id()->str();
mindspore::fl::PBMetadata device_metas =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxDeviceMetas);
mindspore::fl::FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
MS_LOG(INFO) << "exchange key for fl id " << fl_id;
if (fl_id_to_meta.fl_id_to_meta().count(fl_id) == 0) {
std::string reason = "devices_meta for " + fl_id + " is not set. Please retry later.";
BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator);
MS_LOG(ERROR) << reason;
return false;
}
// step 1: get clientlist and client keys from memory server.
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
std::vector<std::string> client_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
// step2: process new item data. and update new item data to memory server.
size_t cur_clients_num = client_list.size();
size_t cur_clients_has_keys_num = record_public_keys.size();
size_t cur_clients_has_keys_num = client_public_keys.size();
if (cur_clients_num != cur_clients_has_keys_num) {
std::string reason = "client num and keys num are not equal.";
MS_LOG(ERROR) << reason;
MS_LOG(ERROR) << "cur_clients_num is " << cur_clients_num << ". cur_clients_has_keys_num is "
<< cur_clients_has_keys_num;
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, reason, next_req_time,
cur_iterator);
return false;
MS_LOG(WARNING) << reason;
MS_LOG(WARNING) << "cur_clients_num is " << cur_clients_num << ". cur_clients_has_keys_num is "
<< cur_clients_has_keys_num;
}
MS_LOG(WARNING) << "exchange_key_threshold " << cipher_init_->exchange_key_threshold << ". cur_clients_num "
<< cur_clients_num << ". cur_clients_keys_num " << cur_clients_has_keys_num;
MS_LOG(INFO) << "client_num_need_ " << cipher_init_->client_num_need_ << ". cur_clients_num " << cur_clients_num;
std::string fl_id = exchange_keys_req->fl_id()->str();
if (cur_clients_num >= cipher_init_->client_num_need_) { // the client num is enough, return false.
MS_LOG(ERROR) << "The server has received enough requests and refuse this request.";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime,
"The server has received enough requests and refuse this request.", next_req_time,
cur_iterator);
return false;
}
if (record_public_keys.find(fl_id) != record_public_keys.end()) { // the client already exists, return false.
MS_LOG(INFO) << "The server has received the request, please do not request again.";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
if (client_public_keys.find(fl_id) != client_public_keys.end()) { // the client already exists, return false.
MS_LOG(ERROR) << "The server has received the request, please do not request again.";
BuildExchangeKeysRsp(fbb, schema::ResponseCode_SUCCEED,
"The server has received the request, please do not request again.", next_req_time,
cur_iterator);
return false;
}
// Gets the members of the deserialized data exchange_keys_req
auto fbs_cpk = exchange_keys_req->c_pk();
size_t cpk_len = fbs_cpk->size();
auto fbs_spk = exchange_keys_req->s_pk();
size_t spk_len = fbs_spk->size();
// transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
std::vector<std::vector<unsigned char>> cur_public_key;
std::vector<unsigned char> cpk(cpk_len);
std::vector<unsigned char> spk(spk_len);
bool ret_create_code_cpk = CreateArray<unsigned char>(&cpk, *fbs_cpk);
bool ret_create_code_spk = CreateArray<unsigned char>(&spk, *fbs_spk);
if (!(ret_create_code_cpk && ret_create_code_spk)) {
MS_LOG(ERROR) << "create cur_public_key failed";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, "update key or client failed",
next_req_time, cur_iterator);
return false;
}
cur_public_key.push_back(cpk);
cur_public_key.push_back(spk);
bool retcode_key =
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, fl_id, cur_public_key);
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
if (retcode_key && retcode_client) {
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
"Success, but the server is not ready yet.", next_req_time, cur_iterator);
BuildExchangeKeysRsp(fbb, schema::ResponseCode_SUCCEED, "Success, but the server is not ready yet.", next_req_time,
cur_iterator);
return true;
} else {
MS_LOG(ERROR) << "update key or client failed";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, "update key or client failed",
next_req_time, cur_iterator);
BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "update key or client failed", next_req_time,
cur_iterator);
return false;
}
}
void CipherKeys::BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder,
const schema::ResponseCode retcode, const std::string &reason,
const std::string &next_req_time, const int iteration) {
auto rsp_reason = exchange_keys_resp_builder->CreateString(reason);
auto rsp_next_req_time = exchange_keys_resp_builder->CreateString(next_req_time);
schema::ResponseExchangeKeysBuilder rsp_builder(*(exchange_keys_resp_builder.get()));
void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time,
const int iteration) {
auto rsp_reason = fbb->CreateString(reason);
auto rsp_next_req_time = fbb->CreateString(next_req_time);
schema::ResponseExchangeKeysBuilder rsp_builder(*(fbb.get()));
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_next_req_time(rsp_next_req_time);
rsp_builder.add_iteration(iteration);
auto rsp_exchange_keys = rsp_builder.Finish();
exchange_keys_resp_builder->Finish(rsp_exchange_keys);
fbb->Finish(rsp_exchange_keys);
return;
}
bool CipherKeys::BuildGetKeys(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good) {
bool flag = true;
if (is_good) {
// convert client keys to standard keys list.
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
MS_LOG(INFO) << "Get Keys: ";
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
if (record_public_keys.size() < cipher_init_->client_num_need_) {
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
<< "clients num: " << cipher_init_->client_num_need_;
flag = false;
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
} else {
for (auto iter = record_public_keys.begin(); iter != record_public_keys.end(); ++iter) {
// read (fl_id, c_pk, s_pk) from the map: record_public_keys_
std::string fl_id = iter->first;
MS_LOG(INFO) << "fl_id : " << fl_id;
// To serialize the members to a new TableClientPublicKeys
auto fbs_fl_id = fbb->CreateString(fl_id);
auto fbs_c_pk = fbb->CreateVector(iter->second[0].data(), iter->second[0].size());
auto fbs_s_pk = fbb->CreateVector(iter->second[1].data(), iter->second[1].size());
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk);
public_keys_list.push_back(cur_public_key);
}
auto remote_publickeys = fbb->CreateVector(public_keys_list);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_remote_publickeys(remote_publickeys);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
}
} else {
void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good) {
if (!is_good) {
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
return;
}
return flag;
const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientsKeys);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
std::string fl_id = iter->first;
fl::KeysPb keys_pb = iter->second;
auto fbs_fl_id = fbb->CreateString(fl_id);
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size());
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
auto fbs_pw_iv = fbb->CreateVector(pw_iv.data(), pw_iv.size());
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
auto fbs_pw_salt = fbb->CreateVector(pw_salt.data(), pw_salt.size());
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
public_keys_list.push_back(cur_public_key);
}
auto remote_publickeys = fbb->CreateVector(public_keys_list);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_remote_publickeys(remote_publickeys);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
return;
}
void CipherKeys::ClearKeys() {

View File

@ -44,21 +44,19 @@ class CipherKeys {
// handle the client's request of get keys.
bool GetKeys(const int cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &get_exchange_keys_resp_builder);
const schema::GetExchangeKeys *get_exchange_keys_req, const std::shared_ptr<fl::server::FBBuilder> &fbb);
// handle the client's request of exchange keys.
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder);
const std::shared_ptr<fl::server::FBBuilder> &fbb);
// build response code of get keys.
bool BuildGetKeys(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good);
void BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good);
// build response code of exchange keys.
void BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder,
const schema::ResponseCode retcode, const std::string &reason,
const std::string &next_req_time, const int iteration);
void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time, const int iteration);
// clear the shared memory.
void ClearKeys();

View File

@ -51,7 +51,7 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect
}
void CipherMetaStorage::GetClientKeysFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) {
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) {
const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
@ -60,12 +60,33 @@ void CipherMetaStorage::GetClientKeysFromServer(
// const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
std::string fl_id = iter->first;
fl::KeysPb keys_pb = iter->second;
std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
std::vector<std::vector<unsigned char>> cur_keys;
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
std::vector<std::vector<uint8_t>> cur_keys;
cur_keys.push_back(cpk);
cur_keys.push_back(spk);
clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<unsigned char>>>(fl_id, cur_keys));
clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
}
}
void CipherMetaStorage::GetClientIVsFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) {
const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
std::string fl_id = iter->first;
fl::KeysPb keys_pb = iter->second;
std::vector<uint8_t> ind_iv(keys_pb.ind_iv().begin(), keys_pb.ind_iv().end());
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
std::vector<std::vector<uint8_t>> cur_ivs;
cur_ivs.push_back(ind_iv);
cur_ivs.push_back(pw_iv);
cur_ivs.push_back(pw_salt);
clients_ivs_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs));
}
}
@ -73,9 +94,18 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
const fl::PBMetadata &clients_noises_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
int count = 0;
int count_thld = 100;
while (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "GetClientNoisesFromServer NULL.";
std::this_thread::sleep_for(std::chrono::milliseconds(50));
int register_time = 500;
std::this_thread::sleep_for(std::chrono::milliseconds(register_time));
count++;
if (count >= count_thld) break;
}
MS_LOG(INFO) << "GetClientNoisesFromServer Count: " << count;
if (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(ERROR) << "GetClientNoisesFromServer NULL.";
return false;
}
cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(),
clients_noises_pb.one_client_noises().noise().end());
@ -98,14 +128,14 @@ bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char
}
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
bool retcode = true;
fl::FLId fl_id_pb;
fl_id_pb.set_fl_id(fl_id);
fl::PBMetadata client_pb;
client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
return retcode;
}
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
MS_LOG(INFO) << "register prime: " << prime;
fl::Prime prime_id_pb;
@ -117,9 +147,9 @@ void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &
}
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
const std::vector<std::vector<unsigned char>> &cur_public_key) {
bool retcode = true;
if (cur_public_key.size() < 2) {
const std::vector<std::vector<uint8_t>> &cur_public_key) {
size_t correct_size = 2;
if (cur_public_key.size() < correct_size) {
MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
return false;
}
@ -132,7 +162,73 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
fl::PBMetadata client_and_keys_pb;
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
return retcode;
}
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
const schema::RequestExchangeKeys *exchange_keys_req) {
std::string fl_id = exchange_keys_req->fl_id()->str();
auto fbs_cpk = exchange_keys_req->c_pk();
auto fbs_spk = exchange_keys_req->s_pk();
if (fbs_cpk == nullptr || fbs_spk == nullptr) {
MS_LOG(ERROR) << "public key from exchange_keys_req is null";
return false;
}
size_t spk_len = fbs_spk->size();
size_t cpk_len = fbs_cpk->size();
// transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
std::vector<std::vector<uint8_t>> cur_public_key;
std::vector<uint8_t> cpk(cpk_len);
std::vector<uint8_t> spk(spk_len);
bool ret_create_code_cpk = CreateArray<uint8_t>(&cpk, *fbs_cpk);
bool ret_create_code_spk = CreateArray<uint8_t>(&spk, *fbs_spk);
if (!(ret_create_code_cpk && ret_create_code_spk)) {
MS_LOG(ERROR) << "create array for public keys failed";
return false;
}
cur_public_key.push_back(cpk);
cur_public_key.push_back(spk);
auto fbs_ind_iv = exchange_keys_req->ind_iv();
std::vector<char> ind_iv;
if (fbs_ind_iv == nullptr) {
MS_LOG(WARNING) << "ind_iv in exchange_keys_req is nullptr";
} else {
ind_iv.assign(fbs_ind_iv->begin(), fbs_ind_iv->end());
}
auto fbs_pw_iv = exchange_keys_req->pw_iv();
std::vector<char> pw_iv;
if (fbs_pw_iv == nullptr) {
MS_LOG(WARNING) << "pw_iv in exchange_keys_req is nullptr";
} else {
pw_iv.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
}
auto fbs_pw_salt = exchange_keys_req->pw_salt();
std::vector<char> pw_salt;
if (fbs_pw_salt == nullptr) {
MS_LOG(WARNING) << "pw_salt in exchange_keys_req is nullptr";
} else {
pw_salt.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
}
// update new item to memory server.
fl::KeysPb keys;
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
keys.set_ind_iv(ind_iv.data(), ind_iv.size());
keys.set_pw_iv(pw_iv.data(), pw_iv.size());
keys.set_pw_salt(pw_salt.data(), pw_salt.size());
fl::PairClientKeys pair_client_keys_pb;
pair_client_keys_pb.set_fl_id(fl_id);
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
fl::PBMetadata client_and_keys_pb;
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
return retcode;
}
@ -142,13 +238,13 @@ bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const s
*noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
fl::PBMetadata client_noises_pb;
client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
return fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
return ret;
}
bool CipherMetaStorage::UpdateClientShareToServer(
const char *list_name, const std::string &fl_id,
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
bool retcode = true;
int size_shares = shares->size();
fl::SharesPb shares_pb;
for (int index = 0; index < size_shares; ++index) {
@ -166,14 +262,17 @@ bool CipherMetaStorage::UpdateClientShareToServer(
pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
fl::PBMetadata client_and_shares_pb;
client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
return retcode;
}
void CipherMetaStorage::RegisterClass() {
fl::PBMetadata exchange_kyes_client_list;
fl::PBMetadata exchange_keys_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
exchange_kyes_client_list);
exchange_keys_client_list);
fl::PBMetadata get_keys_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetKeysClientList,
get_keys_client_list);
fl::PBMetadata clients_keys;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
fl::PBMetadata reconstruct_client_list;
@ -185,9 +284,15 @@ void CipherMetaStorage::RegisterClass() {
fl::PBMetadata share_secretes_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
share_secretes_client_list);
fl::PBMetadata get_secretes_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetSecretsClientList,
get_secretes_client_list);
fl::PBMetadata clients_encrypt_shares;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
clients_encrypt_shares);
fl::PBMetadata get_update_clients_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetUpdateModelClientList,
get_update_clients_list);
}
} // namespace armour
} // namespace mindspore

View File

@ -31,23 +31,39 @@
#include "fl/server/distributed_metadata_store.h"
#include "fl/server/common.h"
#define IND_IV_INDEX 0
#define PW_IV_INDEX 1
#define PW_SALT_INDEX 2
namespace mindspore {
namespace armour {
template <typename T1>
bool CreateArray(std::vector<T1> *newData, const flatbuffers::Vector<T1> &fbs_arr) {
if (newData == nullptr) return false;
size_t size = newData->size();
size_t size_fbs_arr = fbs_arr.size();
if (size != size_fbs_arr) return false;
for (size_t i = 0; i < size; ++i) {
newData->at(i) = fbs_arr.Get(i);
}
return true;
}
constexpr int SHARE_MAX_SIZE = 256;
constexpr int SECRET_MAX_LEN_DOUBLE = 66;
struct clientshare_str {
std::string fl_id;
std::vector<unsigned char> share;
std::vector<uint8_t> share;
int index;
};
struct CipherPublicPara {
int t;
int g;
unsigned char prime[PRIME_MAX_LEN];
unsigned char p[SECRET_MAX_LEN];
uint8_t prime[PRIME_MAX_LEN];
uint8_t p[SECRET_MAX_LEN];
float dp_eps;
float dp_delta;
float dp_norm_clip;
@ -62,7 +78,7 @@ class CipherMetaStorage {
// Register Prime.
void RegisterPrime(const char *list_name, const std::string &prime);
// Get tprime from shared server.
bool GetPrimeFromServer(const char *prime_name, unsigned char *prime);
bool GetPrimeFromServer(const char *prime_name, uint8_t *prime);
// Get client shares from shared server.
void GetClientSharesFromServer(const char *list_name,
std::map<std::string, std::vector<clientshare_str>> *clients_shares_list);
@ -70,14 +86,18 @@ class CipherMetaStorage {
void GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list);
// Get client keys from shared server.
void GetClientKeysFromServer(const char *list_name,
std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list);
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list);
void GetClientIVsFromServer(const char *list_name,
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list);
// Get client noises from shared server.
bool GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise);
// Update client fl_id to shared server.
bool UpdateClientToServer(const char *list_name, const std::string &fl_id);
// Update client key to shared server.
bool UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
const std::vector<std::vector<unsigned char>> &cur_public_key);
const std::vector<std::vector<uint8_t>> &cur_public_key);
// Update client key with signature to shared server.
bool UpdateClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req);
// Update client noise to shared server.
bool UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise);
// Update client share to shared server.

View File

@ -16,33 +16,35 @@
#include "fl/armour/cipher/cipher_reconstruct.h"
#include "fl/server/common.h"
#include "fl/armour/secure_protocol/random.h"
#include "fl/armour/secure_protocol/masking.h"
#include "fl/armour/secure_protocol/key_agreement.h"
#include "fl/armour/cipher/cipher_meta_storage.h"
namespace mindspore {
namespace armour {
bool CipherReconStruct::CombineMask(
std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_keys,
const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
const std::vector<string> &client_list) {
bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp,
std::map<std::string, std::vector<float>> *client_noise,
const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
const std::vector<string> &client_list,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs) {
bool retcode = true;
#ifdef _WIN32
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
retcode = false;
#else
if (shares_tmp == nullptr || client_noise == nullptr) {
MS_LOG(ERROR) << "shares_tmp or client_noise is nullptr.";
return false;
}
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
// define flag_share: judge we need b or s
bool flag_share = true;
const std::string fl_id = iter->first;
std::vector<std::string>::const_iterator ptr = client_list.begin();
for (; ptr < client_list.end(); ++ptr) {
if (*ptr == fl_id) {
flag_share = false;
break;
}
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) {
// the client is online
flag_share = false;
}
MS_LOG(INFO) << "fl_id_src : " << fl_id;
BIGNUM *prime = BN_new();
@ -61,7 +63,6 @@ bool CipherReconStruct::CombineMask(
MS_LOG(ERROR) << "shares_tmp copy failed";
retcode = false;
}
MS_LOG(INFO) << "fl_id_des : " << (iter->second)[i].fl_id;
std::string print_share_data(reinterpret_cast<const char *>(shares_tmp->at(i)->data), shares_tmp->at(i)->len);
}
MS_LOG(INFO) << "end assign secrets shares to public shares ";
@ -74,23 +75,42 @@ bool CipherReconStruct::CombineMask(
MS_LOG(INFO) << "combine secrets shares Success.";
if (flag_share) {
MS_LOG(INFO) << "start get complete s_uv.";
// reconstruct pairwise noise
MS_LOG(INFO) << "start reconstruct pairwise noise.";
std::vector<float> noise(cipher_init_->featuremap_, 0.0);
if (GetSuvNoise(clients_share_list, record_public_keys, fl_id, &noise, secret, length) == false)
if (GetSuvNoise(clients_share_list, record_public_keys, client_ivs, fl_id, &noise, secret, length) == false)
retcode = false;
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
MS_LOG(INFO) << " fl_id : " << fl_id;
MS_LOG(INFO) << "end get complete s_uv.";
} else {
// reconstruct individual noise
MS_LOG(INFO) << "start reconstruct individual noise.";
std::vector<float> noise;
if (Random::RandomAESCTR(&noise, cipher_init_->featuremap_, (const unsigned char *)secret, SECRET_MAX_LEN) < 0)
auto it = client_ivs.find(fl_id);
if (it == client_ivs.end()) {
MS_LOG(ERROR) << "cannot get ivs for client: " << fl_id;
return false;
}
if (it->second.size() != IV_NUM) {
MS_LOG(ERROR) << "get " << it->second.size() << " ivs, the iv num required is: " << IV_NUM;
return false;
}
std::vector<uint8_t> ind_iv = it->second[0];
if (Masking::GetMasking(&noise, cipher_init_->featuremap_, (const uint8_t *)secret, SECRET_MAX_LEN,
ind_iv.data(), ind_iv.size()) < 0)
retcode = false;
for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) {
noise[index_noise] *= -1;
}
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
MS_LOG(INFO) << " fl_id : " << fl_id;
client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
}
} else {
MS_LOG(ERROR) << "reconstruct secret failed: the number of secret shares for fl_id: " << fl_id
<< " is not enough";
MS_LOG(ERROR) << "get " << iter->second.size()
<< "shares, however the secrets_minnums_ required is: " << cipher_init_->secrets_minnums_;
return false;
}
}
#endif
@ -98,173 +118,188 @@ bool CipherReconStruct::CombineMask(
}
bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &client_list) {
// get reconstruct_secret_list_ori from memory server
// get reconstruct_secrets from memory server
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
bool retcode = true;
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori;
std::map<std::string, std::vector<clientshare_str>> reconstruct_secrets;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&reconstruct_secret_list_ori);
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
&reconstruct_secrets);
std::map<std::string, std::vector<std::vector<uint8_t>>> record_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
std::vector<std::string> clients_reconstruct_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
&clients_reconstruct_list);
std::map<std::string, std::vector<std::vector<uint8_t>>> client_ivs;
cipher_init_->cipher_meta_storage_.GetClientIVsFromServer(fl::server::kCtxClientsKeys, &client_ivs);
std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list);
if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() ||
record_public_keys.size() < cipher_init_->client_num_need_ ||
clients_share_list.size() < cipher_init_->share_clients_num_need_) {
if (record_public_keys.size() < cipher_init_->exchange_key_threshold ||
clients_share_list.size() < cipher_init_->share_secrets_threshold ||
record_public_keys.size() != client_ivs.size()) {
MS_LOG(ERROR) << "send share client size: " << clients_share_list.size()
<< ", send public-key client size: " << record_public_keys.size()
<< ", send ivs client size: " << client_ivs.size();
MS_LOG(ERROR) << "get data from server memory failed";
return false;
}
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list;
ConvertSharesToShares(reconstruct_secret_list_ori, &reconstruct_secret_list);
if (!ConvertSharesToShares(reconstruct_secrets, &reconstruct_secret_list)) {
MS_LOG(ERROR) << "ConvertSharesToShares failed.";
return false;
}
MS_LOG(ERROR) << "recombined shares";
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
MS_LOG(ERROR) << "fl_id: " << iter->first;
MS_LOG(ERROR) << "share size: " << iter->second.size();
}
std::vector<Share *> shares_tmp;
if (MallocShares(&shares_tmp, cipher_init_->secrets_minnums_) == false) {
if (!MallocShares(&shares_tmp, cipher_init_->secrets_minnums_)) {
MS_LOG(ERROR) << "Reconstruct malloc shares_tmp invalid.";
return false;
}
MS_LOG(INFO) << "Reconstruct client list: ";
std::vector<std::string>::const_iterator ptr_tmp = client_list.begin();
for (; ptr_tmp < client_list.end(); ++ptr_tmp) {
MS_LOG(INFO) << *ptr_tmp;
}
MS_LOG(INFO) << "Reconstruct secrets shares: ";
std::map<std::string, std::vector<float>> client_keys;
retcode = CombineMask(&shares_tmp, &client_keys, clients_share_list, record_public_keys, reconstruct_secret_list,
client_list);
std::map<std::string, std::vector<float>> client_noise;
retcode = CombineMask(&shares_tmp, &client_noise, clients_share_list, record_public_keys, reconstruct_secret_list,
client_list, client_ivs);
DeleteShares(&shares_tmp);
if (retcode) {
std::vector<float> noise;
if (GetNoiseMasksSum(&noise, client_keys) == false) {
if (!GetNoiseMasksSum(&noise, client_noise)) {
MS_LOG(ERROR) << " GetNoiseMasksSum failed";
return false;
}
client_keys.clear();
client_noise.clear();
MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";
if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise) == false)
if (!cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise)) {
MS_LOG(ERROR) << " ReconstructSecretsGenNoise failed. because UpdateClientNoiseToServer failed";
return false;
}
MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
} else {
MS_LOG(INFO) << " ReconstructSecretsGenNoise failed. because gen noise inside failed";
MS_LOG(ERROR) << " ReconstructSecretsGenNoise failed. because gen noise inside failed";
}
return retcode;
}
// reconstruct secrets
bool CipherReconStruct::ReconstructSecrets(
const int cur_iterator, const std::string &next_req_time, const schema::SendReconstructSecret *reconstruct_secret_req,
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
const std::vector<std::string> &client_list) {
bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb,
const std::vector<std::string> &client_list) {
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
clock_t start_time = clock();
if (reconstruct_secret_req == nullptr || reconstruct_secret_resp_builder == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr. ";
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_RequestError,
"Request is nullptr or Response builder is nullptr.", cur_iterator, next_req_time);
return false;
}
if (client_list.size() < cipher_init_->reconstruct_clients_num_need_) {
MS_LOG(ERROR) << "illegal parameters. update model client_list size: " << client_list.size();
BuildReconstructSecretsRsp(
reconstruct_secret_resp_builder, schema::ResponseCode_RequestError,
"illegal parameters: update model client_list size must larger than reconstruct_clients_num_need", cur_iterator,
next_req_time);
return false;
}
std::vector<std::string> clients_reconstruct_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
&clients_reconstruct_list);
std::map<std::string, std::vector<clientshare_str>> clients_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&clients_shares_all);
size_t count_client_num = clients_shares_all.size();
if (count_client_num != clients_reconstruct_list.size()) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
"shares client size and client size are not equal.", cur_iterator, next_req_time);
MS_LOG(ERROR) << "shares client size and client size are not equal.";
if (reconstruct_secret_req == nullptr) {
std::string reason = "Request is nullptr";
MS_LOG(ERROR) << reason;
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, cur_iterator, next_req_time);
return false;
}
int iterator = reconstruct_secret_req->iteration();
std::string fl_id = reconstruct_secret_req->fl_id()->str();
if (iterator != cur_iterator) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"The iteration round of the client does not match the current iteration.", cur_iterator,
next_req_time);
MS_LOG(ERROR) << "Client " << fl_id << " The iteration round of the client does not match the current iteration.";
return false;
}
if (find(client_list.begin(), client_list.end(), fl_id) == client_list.end()) { // client not in client list.
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
"The client is not in update model client list.", cur_iterator, next_req_time);
MS_LOG(ERROR) << "The client " << fl_id << " is not in update model client list.";
if (client_list.size() < cipher_init_->reconstruct_secrets_threshold) {
MS_LOG(ERROR) << "illegal parameters. update model client_list size: " << client_list.size();
BuildReconstructSecretsRsp(
fbb, schema::ResponseCode_RequestError,
"illegal parameters: update model client_list size must larger than reconstruct_clients_num_need", cur_iterator,
next_req_time);
return false;
}
if (find(clients_reconstruct_list.begin(), clients_reconstruct_list.end(), fl_id) != clients_reconstruct_list.end()) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
"Client has sended messages.", cur_iterator, next_req_time);
std::vector<std::string> get_clients_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetUpdateModelClientList,
&get_clients_list);
// client not in get client list.
if (find(get_clients_list.begin(), get_clients_list.end(), fl_id) == get_clients_list.end()) {
std::string reason;
MS_LOG(INFO) << "The client " << fl_id << " is not in get update model client list.";
// client in update model client list.
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) {
reason = "The client " + fl_id + " is not in get clients list, but in update model client list.";
MS_LOG(INFO) << reason;
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, reason, cur_iterator, next_req_time);
return false;
}
reason = "The client " + fl_id + " is not in get clients list, and not in update model client list.";
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, "The client is not in update model client list.",
cur_iterator, next_req_time);
return false;
}
std::map<std::string, std::vector<clientshare_str>> reconstruct_shares;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&reconstruct_shares);
size_t count_client_num = reconstruct_shares.size();
if (reconstruct_shares.find(fl_id) != reconstruct_shares.end()) {
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Client has sended messages.", cur_iterator,
next_req_time);
MS_LOG(INFO) << "Error, client " << fl_id << " has sended messages.";
return false;
}
auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id);
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
if (!(retcode_client && retcode_share)) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
"reconstruct update shares or client failed.", cur_iterator, next_req_time);
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "reconstruct update shares or client failed.",
cur_iterator, next_req_time);
MS_LOG(ERROR) << "reconstruct update shares or client failed.";
return false;
}
count_client_num = count_client_num + 1;
if (count_client_num < cipher_init_->reconstruct_clients_num_need_) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
"Success,but the server is not ready to reconstruct secret yet.", cur_iterator,
if (count_client_num < cipher_init_->reconstruct_secrets_threshold) {
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED,
"Success, but the server is not ready to reconstruct secret yet.", cur_iterator,
next_req_time);
MS_LOG(INFO) << "ReconstructSecrets" << fl_id << " Success, but count " << count_client_num << "is not enough.";
MS_LOG(INFO) << "ReconstructSecrets " << fl_id << " Success, but count " << count_client_num << "is not enough.";
return true;
} else {
bool retcode_result = true;
const fl::PBMetadata &clients_noises_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
if (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "Success,the secret will be reconstructed.";
retcode_result = ReconstructSecretsGenNoise(client_list);
if (retcode_result) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
"Success,the secret is reconstructing.", cur_iterator, next_req_time);
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success, reconstruct ok.";
} else {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
"the secret restructs failed.", cur_iterator, next_req_time);
MS_LOG(ERROR) << "the secret restructs failed.";
}
} else {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
"Clients' number is full.", cur_iterator, next_req_time);
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success : no need reconstruct.";
}
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "Reconstruct get + gennoise data time is : " << duration;
return retcode_result;
}
const fl::PBMetadata &clients_noises_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
if (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "Success, the secret will be reconstructed.";
if (ReconstructSecretsGenNoise(client_list)) {
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Success,the secret is reconstructing.",
cur_iterator, next_req_time);
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success, reconstruct ok.";
} else {
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "the secret restructs failed.", cur_iterator,
next_req_time);
MS_LOG(ERROR) << "CipherReconStruct::ReconstructSecrets" << fl_id << " failed.";
}
} else {
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Clients' number is full.", cur_iterator,
next_req_time);
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success : no need reconstruct.";
}
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "Reconstruct get + gennoise data time is : " << duration;
return true;
}
bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,
const std::map<std::string, std::vector<float>> &client_keys) {
const std::map<std::string, std::vector<float>> &client_noise) {
std::vector<float> sum(cipher_init_->featuremap_, 0.0);
for (auto iter = client_keys.begin(); iter != client_keys.end(); iter++) {
for (auto iter = client_noise.begin(); iter != client_noise.end(); iter++) {
if (iter->second.size() != cipher_init_->featuremap_) {
return false;
}
@ -301,35 +336,52 @@ void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::ser
return;
}
bool CipherReconStruct::GetSuvNoise(
const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys, const string &fl_id,
std::vector<float> *noise, uint8_t *secret, int length) {
bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs,
const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length) {
for (auto p_key = clients_share_list.begin(); p_key != clients_share_list.end(); ++p_key) {
if (*p_key != fl_id) {
PrivateKey *privKey1 = KeyAgreement::FromPrivateBytes((unsigned char *)secret, length);
if (privKey1 == NULL) {
MS_LOG(ERROR) << "create privKey1 failed\n";
PrivateKey *privKey = KeyAgreement::FromPrivateBytes(secret, length);
if (privKey == NULL) {
MS_LOG(ERROR) << "create privKey failed\n";
return false;
}
std::vector<unsigned char> public_key = record_public_keys.at(*p_key)[1];
PublicKey *pubKey1 = KeyAgreement::FromPublicBytes(public_key.data(), public_key.size());
if (pubKey1 == NULL) {
MS_LOG(ERROR) << "create pubKey1 failed\n";
std::vector<uint8_t> public_key = record_public_keys.at(*p_key)[1];
std::string iv_fl_id;
if (fl_id < *p_key) {
iv_fl_id = fl_id;
} else {
iv_fl_id = *p_key;
}
auto iter = client_ivs.find(iv_fl_id);
if (iter == client_ivs.end()) {
MS_LOG(ERROR) << "cannot get ivs for client: " << iv_fl_id;
return false;
}
MS_LOG(INFO) << "fl_id : " << fl_id << "other id : " << *p_key;
unsigned char secret1[SECRET_MAX_LEN] = {0};
unsigned char salt[SECRET_MAX_LEN] = {0};
if (KeyAgreement::ComputeSharedKey(privKey1, pubKey1, SECRET_MAX_LEN, salt, SECRET_MAX_LEN, secret1) < 0) {
if (iter->second.size() != IV_NUM) {
MS_LOG(ERROR) << "get " << iter->second.size() << " ivs, the iv num required is: " << IV_NUM;
return false;
}
std::vector<uint8_t> pw_iv = iter->second[PW_IV_INDEX];
std::vector<uint8_t> pw_salt = iter->second[PW_SALT_INDEX];
PublicKey *pubKey = KeyAgreement::FromPublicBytes(public_key.data(), public_key.size());
if (pubKey == NULL) {
MS_LOG(ERROR) << "create pubKey failed\n";
return false;
}
MS_LOG(INFO) << "private_key fl_id : " << fl_id << " public_key fl_id : " << *p_key;
uint8_t secret1[SECRET_MAX_LEN] = {0};
if (KeyAgreement::ComputeSharedKey(privKey, pubKey, SECRET_MAX_LEN, pw_salt.data(), pw_salt.size(), secret1) <
0) {
MS_LOG(ERROR) << "ComputeSharedKey failed\n";
return false;
}
std::vector<float> noise_tmp;
if (Random::RandomAESCTR(&noise_tmp, cipher_init_->featuremap_, (const unsigned char *)secret1, SECRET_MAX_LEN) <
0) {
MS_LOG(ERROR) << "RandomAESCTR failed\n";
if (Masking::GetMasking(&noise_tmp, cipher_init_->featuremap_, (const uint8_t *)secret1, SECRET_MAX_LEN,
pw_iv.data(), pw_iv.size()) < 0) {
MS_LOG(ERROR) << "Get Masking failed\n";
return false;
}
bool symbol_noise = GetSymbol(fl_id, *p_key);
@ -345,9 +397,6 @@ bool CipherReconStruct::GetSuvNoise(
noise->at(index) += noise_tmp[index];
}
}
for (int i = 0; i < 5; i++) {
MS_LOG(INFO) << "index " << i << " : " << noise_tmp[i];
}
}
}
return true;
@ -361,37 +410,41 @@ bool CipherReconStruct::GetSymbol(const std::string &str1, const std::string &st
}
}
void CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
// recombined shares by their source fl_id (ownners)
bool CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
std::map<std::string, std::vector<clientshare_str>> *des) {
for (auto iter_ori = src.begin(); iter_ori != src.end(); ++iter_ori) {
std::string fl_des = iter_ori->first;
auto &cur_clientshare_str = iter_ori->second;
if (des == nullptr) return false;
for (auto iter = src.begin(); iter != src.end(); ++iter) {
std::string des_id = iter->first;
auto &cur_clientshare_str = iter->second;
for (size_t index_clientshare = 0; index_clientshare < cur_clientshare_str.size(); ++index_clientshare) {
std::string fl_src = cur_clientshare_str[index_clientshare].fl_id;
std::string src_id = cur_clientshare_str[index_clientshare].fl_id;
clientshare_str value;
value.fl_id = fl_des;
value.fl_id = des_id;
value.share = cur_clientshare_str[index_clientshare].share;
value.index = cur_clientshare_str[index_clientshare].index;
if (des->find(fl_src) == des->end()) { // fl_id_des is not in reconstruct_secret_list_
if (des->find(src_id) == des->end()) { // src_id is not in recombined shares list
std::vector<clientshare_str> value_list;
value_list.push_back(value);
des->insert(std::pair<std::string, std::vector<clientshare_str>>(fl_src, value_list));
} else { // fl_id_des is in reconstruct_secret_list_
des->at(fl_src).push_back(value);
des->insert(std::pair<std::string, std::vector<clientshare_str>>(src_id, value_list));
} else {
des->at(src_id).push_back(value);
}
}
}
return true;
}
bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int shares_size) {
if (shares_tmp == nullptr) return false;
for (int i = 0; i < shares_size; ++i) {
Share *share_i = new Share;
Share *share_i = new Share();
if (share_i == nullptr) {
MS_LOG(ERROR) << "shares_tmp " << i << " memory to cipher is invalid.";
DeleteShares(shares_tmp);
return false;
}
share_i->data = new unsigned char[SHARE_MAX_SIZE];
share_i->data = new uint8_t[SHARE_MAX_SIZE];
if (share_i->data == nullptr) {
MS_LOG(ERROR) << "shares_tmp's data " << i << " memory to cipher is invalid.";
DeleteShares(shares_tmp);
@ -405,6 +458,7 @@ bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int share
}
void CipherReconStruct::DeleteShares(std::vector<Share *> *shares_tmp) {
if (shares_tmp == nullptr) return;
if (shares_tmp->size() != 0) {
for (size_t i = 0; i < shares_tmp->size(); ++i) {
if (shares_tmp->at(i) != nullptr && shares_tmp->at(i)->data != nullptr) {

View File

@ -28,6 +28,8 @@
#include "fl/armour/cipher/cipher_init.h"
#include "fl/armour/cipher/cipher_meta_storage.h"
#define IV_NUM 3
namespace mindspore {
namespace armour {
// The process of reconstruct secret mask in the secure aggregation
@ -44,7 +46,7 @@ class CipherReconStruct {
// reconstruct secret mask
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req,
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
const std::shared_ptr<fl::server::FBBuilder> &fbb,
const std::vector<std::string> &client_list);
// build response code of reconstruct secret.
@ -60,26 +62,28 @@ class CipherReconStruct {
bool GetSymbol(const std::string &str1, const std::string &str2);
// get suv noise by computing shares result.
bool GetSuvNoise(const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length);
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs, const string &fl_id,
std::vector<float> *noise, uint8_t *secret, int length);
// malloc shares.
bool MallocShares(std::vector<Share *> *shares_tmp, int shares_size);
// delete shares.
void DeleteShares(std::vector<Share *> *shares_tmp);
// convert shares from receiving clients to sending clients.
void ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
bool ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
std::map<std::string, std::vector<clientshare_str>> *des);
// generate noise from shares.
bool ReconstructSecretsGenNoise(const std::vector<string> &client_list);
// get noise masks sum.
bool GetNoiseMasksSum(std::vector<float> *result, const std::map<std::string, std::vector<float>> &client_keys);
bool GetNoiseMasksSum(std::vector<float> *result, const std::map<std::string, std::vector<float>> &client_noise);
// combine noise mask.
bool CombineMask(std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_keys,
bool CombineMask(std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_noise,
const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
const std::vector<string> &client_list);
const std::vector<string> &client_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &client_ivs);
};
} // namespace armour
} // namespace mindspore

View File

@ -25,8 +25,8 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
const string next_req_time) {
MS_LOG(INFO) << "CipherShares::ShareSecrets START";
if (share_secrets_req == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
std::string reason = "Request is nullptr or Response builder is nullptr.";
std::string reason = "Request is nullptr";
MS_LOG(ERROR) << reason;
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time,
cur_iterator);
return false;
@ -34,38 +34,34 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
// step 1: get client list and share secrets from memory server.
clock_t start_time = clock();
int iteration = share_secrets_req->iteration();
std::vector<std::string> get_keys_clients;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetKeysClientList, &get_keys_clients);
std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list);
std::vector<std::string> clients_exchange_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList,
&clients_exchange_list);
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
&encrypted_shares_all);
MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size()
<< "client of shares size : " << clients_share_list.size() << "shares size"
<< encrypted_shares_all.size();
if (encrypted_shares_all.size() != clients_share_list.size()) {
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime,
"client of shares and shares size are not equal", next_req_time, cur_iterator);
MS_LOG(ERROR) << "client of shares and shares size are not equal. client of shares size : "
<< clients_share_list.size() << "shares size" << encrypted_shares_all.size();
}
MS_LOG(INFO) << "Client of get keys size : " << get_keys_clients.size()
<< "client of update shares size : " << clients_share_list.size()
<< "updated shares size: " << encrypted_shares_all.size();
// step 2: update new item to memory server. serialise: update pb struct to memory server.
int iteration = share_secrets_req->iteration();
std::string fl_id_src = share_secrets_req->fl_id()->str();
if (find(clients_exchange_list.begin(), clients_exchange_list.end(), fl_id_src) ==
clients_exchange_list.end()) { // the client not in clients_exchange_list, return false.
if (find(get_keys_clients.begin(), get_keys_clients.end(), fl_id_src) == get_keys_clients.end()) {
// the client not in get keys clients
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError,
("client share secret is not in clients_exchange list. && client is illegal"), next_req_time,
("client share secret is not in getkeys list. && client is illegal"), next_req_time,
iteration);
return false;
}
if (find(clients_share_list.begin(), clients_share_list.end(), fl_id_src) !=
clients_share_list.end()) { // the client is already exists, return false.
if (encrypted_shares_all.find(fl_id_src) != encrypted_shares_all.end()) { // the client is already exists
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED,
("client sharesecret already exists."), next_req_time, iteration);
return false;
@ -74,24 +70,22 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
// update new item to memory server.
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
(share_secrets_req->encrypted_shares());
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src);
bool retcode = retcode_share && retcode_client;
if (retcode) {
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
MS_LOG(INFO) << "CipherShares::ShareSecrets Success";
} else {
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
if (!(retcode_share && retcode_client)) {
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime,
"update client of shares and shares failed", next_req_time, iteration);
MS_LOG(ERROR) << "CipherShares::ShareSecrets update client of shares and shares failed ";
return false;
}
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
MS_LOG(INFO) << "CipherShares::ShareSecrets Success";
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "ShareSecrets get + deal + update data time is : " << duration;
return retcode;
return true;
}
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
@ -101,36 +95,38 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
clock_t start_time = clock();
// step 0: check whether the parameters are legal.
if (get_secrets_req == nullptr) {
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, 0);
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, nullptr);
MS_LOG(ERROR) << "GetSecrets: get_secrets_req is nullptr.";
return false;
}
// step 1: get client list and client shares list from memory server.
std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list);
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
&encrypted_shares_all);
int iteration = get_secrets_req->iteration();
size_t share_clients_num = clients_share_list.size();
size_t cients_has_shares = encrypted_shares_all.size();
if (share_clients_num != cients_has_shares) {
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_OutOfTime, iteration, next_req_time, 0);
MS_LOG(ERROR) << "cients_has_shares: " << cients_has_shares << "share_clients_num: " << share_clients_num;
}
if (cipher_init_->share_clients_num_need_ > share_clients_num) { // the client num is not enough, return false.
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, 0);
MS_LOG(INFO) << "GetSecrets: the client num is not enough: share_clients_num_need_: "
<< cipher_init_->share_clients_num_need_ << "share_clients_num: " << share_clients_num;
size_t encrypted_shares_num = encrypted_shares_all.size();
if (cipher_init_->share_secrets_threshold > encrypted_shares_num) { // the client num is not enough, return false.
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
MS_LOG(INFO) << "GetSecrets: the encrypted shares num is not enough: share_secrets_threshold: "
<< cipher_init_->share_secrets_threshold << "encrypted_shares_num: " << encrypted_shares_num;
return false;
}
std::string fl_id = get_secrets_req->fl_id()->str();
if (find(clients_share_list.begin(), clients_share_list.end(), fl_id) ==
clients_share_list.end()) { // the client is not in client list, return false.
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, 0);
MS_LOG(ERROR) << "GetSecrets: client is not in client list.";
// the client is not in share secrets client list.
if (encrypted_shares_all.find(fl_id) == encrypted_shares_all.end()) {
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, nullptr);
MS_LOG(ERROR) << "GetSecrets: client is not in share secrets client list.";
return false;
}
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetSecretsClientList, fl_id);
if (!retcode_client) {
MS_LOG(ERROR) << "update get secrets clients failed";
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
return false;
}
// get the result client shares.
@ -171,7 +167,6 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SUCCEED, iteration, next_req_time,
&encrypted_shares);
MS_LOG(INFO) << "CipherShares::GetSecrets Success";
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
@ -185,7 +180,7 @@ void CipherShares::BuildGetSecretsRsp(
int rsp_retcode = retcode;
int rsp_iteration = iteration;
auto rsp_next_req_time = get_secrets_resp_builder->CreateString(next_req_time);
if (encrypted_shares == 0) {
if (encrypted_shares == nullptr) {
auto get_secrets_rsp =
schema::CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration, 0, rsp_next_req_time);
get_secrets_resp_builder->Finish(get_secrets_rsp);
@ -195,7 +190,6 @@ void CipherShares::BuildGetSecretsRsp(
encrypted_shares_rsp, rsp_next_req_time);
get_secrets_resp_builder->Finish(get_secrets_rsp);
}
return;
}

View File

@ -26,8 +26,8 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
clock_t start_time = clock();
std::vector<float> noise;
(void)cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
if (noise.size() != cipher_init_->featuremap_) {
bool ret = cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
if (!ret || noise.size() != cipher_init_->featuremap_) {
MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
return false;
}

View File

@ -18,12 +18,7 @@
namespace mindspore {
namespace armour {
#define KEY_STEP_MAX 32
#define KEY_STEP_MIN 16
#define PAD_SIZE 5
AESEncrypt::AESEncrypt(const unsigned char *key, int key_len, unsigned char *ivec, int ivec_len, const AES_MODE mode) {
AESEncrypt::AESEncrypt(const uint8_t *key, int key_len, const uint8_t *ivec, int ivec_len, const AES_MODE mode) {
privKey = key;
privKeyLen = key_len;
iVec = ivec;
@ -47,12 +42,20 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
#else
int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) {
int ret;
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) {
MS_LOG(ERROR) << "key length must be 16 or 32!";
if (privKey == NULL || iVec == NULL) {
MS_LOG(ERROR) << "private key or init vector is invalid.";
return -1;
}
if (iVecLen != INIT_VEC_SIZE) {
MS_LOG(ERROR) << "initial vector size must be 16!";
if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
MS_LOG(ERROR) << "key length is invalid.";
return -1;
}
if (iVecLen != AES_IV_SIZE) {
MS_LOG(ERROR) << "initial vector size is invalid.";
return -1;
}
if (data == NULL || len <= 0 || encrypt_data == NULL) {
MS_LOG(ERROR) << "input data is invalid.";
return -1;
}
if (aesMode == AES_CBC || aesMode == AES_CTR) {
@ -69,12 +72,20 @@ int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned c
int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) {
int ret = 0;
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) {
MS_LOG(ERROR) << "key length must be 16 or 32!";
if (privKey == NULL || iVec == NULL) {
MS_LOG(ERROR) << "private key or init vector is invalid.";
return -1;
}
if (iVecLen != INIT_VEC_SIZE) {
MS_LOG(ERROR) << "initial vector size must be 16!";
if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
MS_LOG(ERROR) << "key length is invalid.";
return -1;
}
if (iVecLen != AES_IV_SIZE) {
MS_LOG(ERROR) << "initial vector size is invalid.";
return -1;
}
if (data == NULL || encrypt_len <= 0 || encrypt_data == NULL) {
MS_LOG(ERROR) << "input data is invalid.";
return -1;
}
if (aesMode == AES_CBC || aesMode == AES_CTR) {
@ -88,17 +99,21 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
return 0;
}
int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const unsigned char *key, unsigned char *ivec,
unsigned char *encrypt_data, int *encrypt_len) {
int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *encrypt_data, int *encrypt_len) {
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
return -1;
}
int out_len;
int ret = 0;
int ret;
if (aesMode == AES_CBC) {
switch (privKeyLen) {
case 16:
case KEY_LENGTH_16:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
break;
case 32:
case KEY_LENGTH_32:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
break;
default:
@ -107,16 +122,16 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
}
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex CBC fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
EVP_CIPHER_CTX_set_key_length(ctx, EVP_MAX_KEY_LENGTH);
EVP_CIPHER_CTX_set_padding(ctx, PAD_SIZE);
EVP_CIPHER_CTX_set_padding(ctx, EVP_PADDING_PKCS7);
} else if (aesMode == AES_CTR) {
switch (privKeyLen) {
case 16:
case KEY_LENGTH_16:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
break;
case 32:
case KEY_LENGTH_32:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
break;
default:
@ -125,21 +140,25 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
}
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex CTR fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
} else {
MS_LOG(ERROR) << "Unsupported AES mode";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
ret = EVP_EncryptUpdate(ctx, encrypt_data, &out_len, data, len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptUpdate fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
*encrypt_len = out_len;
ret = EVP_EncryptFinal_ex(ctx, encrypt_data + *encrypt_len, &out_len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptFinal_ex fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
*encrypt_len += out_len;
@ -147,17 +166,21 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
return 0;
}
int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len, const unsigned char *key,
unsigned char *ivec, unsigned char *decrypt_data, int *decrypt_len) {
int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *decrypt_data, int *decrypt_len) {
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
return -1;
}
int out_len;
int ret = 0;
int ret;
if (aesMode == AES_CBC) {
switch (privKeyLen) {
case 16:
case KEY_LENGTH_16:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
break;
case 32:
case KEY_LENGTH_32:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
break;
default:
@ -165,40 +188,46 @@ int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len
ret = -1;
}
if (ret != 1) {
EVP_CIPHER_CTX_free(ctx);
return -1;
}
EVP_CIPHER_CTX_set_key_length(ctx, EVP_MAX_KEY_LENGTH);
} else if (aesMode == AES_CTR) {
switch (privKeyLen) {
case 16:
case KEY_LENGTH_16:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
break;
case 32:
case KEY_LENGTH_32:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
break;
default:
MS_LOG(ERROR) << "key length is incorrect!";
ret = -1;
}
if (ret != 1) {
EVP_CIPHER_CTX_free(ctx);
return -1;
}
} else {
ret = -1;
}
if (ret != 1) {
MS_LOG(ERROR) << "Unsupported AES mode";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
ret = EVP_DecryptUpdate(ctx, decrypt_data, &out_len, encrypt_data, len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptUpdate fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
*decrypt_len = out_len;
ret = EVP_DecryptFinal_ex(ctx, decrypt_data + *decrypt_len, &out_len);
if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptFinal_ex fail!";
EVP_CIPHER_CTX_free(ctx);
return -1;
}
*decrypt_len += out_len;
EVP_CIPHER_CTX_free(ctx);
return 0;
}
#endif

View File

@ -22,7 +22,9 @@
#endif
#include "utils/log_adapter.h"
#define INIT_VEC_SIZE 16
#define AES_IV_SIZE 16
#define KEY_LENGTH_32 32
#define KEY_LENGTH_16 16
namespace mindspore {
namespace armour {
@ -35,21 +37,21 @@ class SymmetricEncrypt : Encrypt {};
class AESEncrypt : SymmetricEncrypt {
public:
AESEncrypt(const unsigned char *key, int key_len, unsigned char *ivec, int ivec_len, AES_MODE mode);
AESEncrypt(const unsigned char *key, int key_len, const uint8_t *ivec, int ivec_len, AES_MODE mode);
~AESEncrypt();
int EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len);
int DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len);
int EncryptData(const uint8_t *data, const int len, uint8_t *encrypt_data, int *encrypt_len);
int DecryptData(const uint8_t *encrypt_data, const int encrypt_len, uint8_t *data, int *len);
private:
const unsigned char *privKey;
const uint8_t *privKey;
int privKeyLen;
unsigned char *iVec;
const uint8_t *iVec;
int iVecLen;
AES_MODE aesMode;
int evp_aes_encrypt(const unsigned char *data, const int len, const unsigned char *key, unsigned char *ivec,
unsigned char *encrypt_data, int *encrypt_len);
int evp_aes_decrypt(const unsigned char *encrypt_data, const int len, const unsigned char *key, unsigned char *ivec,
unsigned char *decrypt_data, int *decrypt_len);
int evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *encrypt_data, int *encrypt_len);
int evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
uint8_t *decrypt_data, int *decrypt_len);
};
} // namespace armour

View File

@ -54,14 +54,22 @@ PrivateKey::PrivateKey(EVP_PKEY *evpKey) { evpPrivKey = evpKey; }
PrivateKey::~PrivateKey() { EVP_PKEY_free(evpPrivKey); }
int PrivateKey::GetPrivateBytes(size_t *len, unsigned char *privKeyBytes) {
int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) {
if (privKeyBytes == nullptr || len <= 0) {
MS_LOG(ERROR) << "input privKeyBytes invalid.";
return -1;
}
if (!EVP_PKEY_get_raw_private_key(evpPrivKey, privKeyBytes, len)) {
return -1;
}
return 0;
}
int PrivateKey::GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) {
int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) {
if (pubKeyBytes == nullptr || len <= 0) {
MS_LOG(ERROR) << "input pubKeyBytes invalid.";
return -1;
}
if (!EVP_PKEY_get_raw_public_key(evpPrivKey, pubKeyBytes, len)) {
return -1;
}
@ -69,7 +77,19 @@ int PrivateKey::GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) {
}
int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned char *salt, int salt_len,
unsigned char *exchangeKey) {
uint8_t *exchangeKey) {
if (peerPublicKey == nullptr) {
MS_LOG(ERROR) << "peerPublicKey is nullptr.";
return -1;
}
if (key_len != KEY_LEN || exchangeKey == nullptr) {
MS_LOG(ERROR) << "exchangeKey is nullptr or input key_len is incorrect.";
return -1;
}
if (salt == nullptr || salt_len != SALT_LEN) {
MS_LOG(ERROR) << "input salt in invalid.";
return -1;
}
EVP_PKEY_CTX *ctx;
size_t len = 0;
ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL);
@ -79,30 +99,38 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c
}
if (EVP_PKEY_derive_init(ctx) <= 0) {
MS_LOG(ERROR) << "EVP_PKEY_derive_init failed!";
EVP_PKEY_CTX_free(ctx);
return -1;
}
if (EVP_PKEY_derive_set_peer(ctx, peerPublicKey->evpPubKey) <= 0) {
MS_LOG(ERROR) << "EVP_PKEY_derive_set_peer failed!";
EVP_PKEY_CTX_free(ctx);
return -1;
}
unsigned char *secret;
if (EVP_PKEY_derive(ctx, NULL, &len) <= 0) {
MS_LOG(ERROR) << "get derive key size failed!";
EVP_PKEY_CTX_free(ctx);
return -1;
}
secret = (unsigned char *)OPENSSL_malloc(len);
if (!secret) {
MS_LOG(ERROR) << "malloc secret memory failed!";
EVP_PKEY_CTX_free(ctx);
return -1;
}
if (EVP_PKEY_derive(ctx, secret, &len) <= 0) {
MS_LOG(ERROR) << "derive key failed!";
OPENSSL_free(secret);
EVP_PKEY_CTX_free(ctx);
return -1;
}
if (!PKCS5_PBKDF2_HMAC((char *)secret, len, salt, salt_len, ITERATION, EVP_sha256(), key_len, exchangeKey)) {
if (!PKCS5_PBKDF2_HMAC(reinterpret_cast<char *>(secret), len, salt, salt_len, ITERATION, EVP_sha256(), key_len,
exchangeKey)) {
OPENSSL_free(secret);
EVP_PKEY_CTX_free(ctx);
return -1;
}
OPENSSL_free(secret);
@ -118,9 +146,11 @@ PrivateKey *KeyAgreement::GeneratePrivKey() {
return NULL;
}
if (EVP_PKEY_keygen_init(pctx) <= 0) {
EVP_PKEY_CTX_free(pctx);
return NULL;
}
if (EVP_PKEY_keygen(pctx, &evpKey) <= 0) {
EVP_PKEY_CTX_free(pctx);
return NULL;
}
EVP_PKEY_CTX_free(pctx);
@ -131,14 +161,30 @@ PrivateKey *KeyAgreement::GeneratePrivKey() {
PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
unsigned char *pubKeyBytes;
size_t len = 0;
if (privKey == nullptr) {
return NULL;
}
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, NULL, &len)) {
return NULL;
}
pubKeyBytes = (unsigned char *)OPENSSL_malloc(len);
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, pubKeyBytes, &len)) {
pubKeyBytes = reinterpret_cast<uint8_t *>(OPENSSL_malloc(len));
if (!pubKeyBytes) {
MS_LOG(ERROR) << "malloc secret memory failed!";
return NULL;
}
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, pubKeyBytes, &len)) {
MS_LOG(ERROR) << "EVP_PKEY_get_raw_public_key failed!";
OPENSSL_free(pubKeyBytes);
return NULL;
}
EVP_PKEY *evp_pubKey =
EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, reinterpret_cast<uint8_t *>(pubKeyBytes), len);
if (evp_pubKey == NULL) {
MS_LOG(ERROR) << "EVP_PKEY_new_raw_public_key failed!";
OPENSSL_free(pubKeyBytes);
return NULL;
}
EVP_PKEY *evp_pubKey = EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, (unsigned char *)pubKeyBytes, len);
OPENSSL_free(pubKeyBytes);
PublicKey *pubKey = new PublicKey(evp_pubKey);
return pubKey;
@ -147,6 +193,7 @@ PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) {
EVP_PKEY *evp_Key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, NULL, data, len);
if (evp_Key == NULL) {
MS_LOG(ERROR) << "create evp_Key from raw bytes failed!";
return NULL;
}
PrivateKey *privKey = new PrivateKey(evp_Key);
@ -164,7 +211,11 @@ PublicKey *KeyAgreement::FromPublicBytes(unsigned char *data, int len) {
}
int KeyAgreement::ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey, int key_len,
const unsigned char *salt, int salt_len, unsigned char *exchangeKey) {
const unsigned char *salt, int salt_len, uint8_t *exchangeKey) {
if (privKey == nullptr) {
MS_LOG(ERROR) << "privKey is nullptr!";
return -1;
}
return privKey->Exchange(peerPublicKey, key_len, salt, salt_len, exchangeKey);
}
#endif

View File

@ -24,7 +24,8 @@
#endif
#include "utils/log_adapter.h"
#define KEK_KEY_LEN 32
#define KEY_LEN 32
#define SALT_LEN 32
#define ITERATION 10000
namespace mindspore {

View File

@ -14,44 +14,39 @@
* limitations under the License.
*/
#include "fl/armour/secure_protocol/random.h"
#include "fl/armour/secure_protocol/masking.h"
namespace mindspore {
namespace armour {
Random::Random(size_t init_seed) { generator.seed(init_seed); }
Random::~Random() {}
#ifdef _WIN32
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
return -1;
}
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
int Masking::GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *seed, int seed_len,
const uint8_t *ivec, int ivec_size) {
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
return -1;
}
#else
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
int retval = RAND_priv_bytes(secret, num_bytes);
return retval;
}
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
if (seed_len != 16 && seed_len != 32) {
MS_LOG(ERROR) << "seed length must be 16 or 32!";
int Masking::GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *secret, int secret_len,
const uint8_t *ivec, int ivec_size) {
if ((secret_len != KEY_LENGTH_16 && secret_len != KEY_LENGTH_32) || secret == NULL) {
MS_LOG(ERROR) << "secret is invalid!";
return -1;
}
if (noise == NULL || noise_len <= 0) {
MS_LOG(ERROR) << "noise is invalid!";
return -1;
}
if (ivec == NULL || ivec_size != AES_IV_SIZE) {
MS_LOG(ERROR) << "ivec is invalid!";
return -1;
}
int size = noise_len * sizeof(int);
std::vector<unsigned char> data(size, 0);
std::vector<unsigned char> encrypt_data(size, 0);
std::vector<unsigned char> ivec(INIT_VEC_SIZE, 0);
std::vector<uint8_t> data(size, 0);
std::vector<uint8_t> encrypt_data(size, 0);
int encrypt_len = 0;
AESEncrypt encrypt(seed, seed_len, ivec.data(), INIT_VEC_SIZE, AES_CTR);
AESEncrypt encrypt(secret, secret_len, ivec, AES_IV_SIZE, AES_CTR);
if (encrypt.EncryptData(data.data(), size, encrypt_data.data(), &encrypt_len) != 0) {
MS_LOG(ERROR) << "call encryptData fail!";
MS_LOG(ERROR) << "call AES-CTR failed!";
return -1;
}

View File

@ -19,27 +19,15 @@
#include <random>
#include <vector>
#ifndef _WIN32
#include <openssl/rand.h>
#endif
#include "fl/armour/secure_protocol/encrypt.h"
namespace mindspore {
namespace armour {
#define RANDOM_LEN 8
class Random {
class Masking {
public:
explicit Random(size_t init_seed);
~Random();
// use openssl RAND_priv_bytes
static int GetRandomBytes(unsigned char *secret, int num_bytes);
static int RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len);
private:
std::default_random_engine generator;
static int GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *secret, int secret_len,
const uint8_t *ivec, int ivec_size);
};
} // namespace armour
} // namespace mindspore

View File

@ -62,6 +62,11 @@ struct RoundConfig {
struct CipherConfig {
float share_secrets_ratio = 1.0;
uint64_t cipher_time_window = 300000;
size_t exchange_keys_threshold = 0;
size_t get_keys_threshold = 0;
size_t share_secrets_threshold = 0;
size_t get_secrets_threshold = 0;
size_t client_list_threshold = 0;
size_t reconstruct_secrets_threshold = 0;
};
@ -207,8 +212,11 @@ constexpr auto kCtxClientNoises = "clients_noises";
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";
constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares";
constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list";
constexpr auto kCtxGetSecretsClientList = "get_secrets_client_list";
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list";
constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
constexpr auto kCtxCipherPrimer = "cipher_primer";

View File

@ -41,108 +41,102 @@ void ClientListKernel::InitKernel(size_t) {
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
std::shared_ptr<server::FBBuilder> fbb) {
bool response = false;
std::vector<string> client_list;
std::vector<string> empty_client_list;
std::string fl_id = get_clients_req->fl_id()->str();
int32_t iter_client = (size_t)get_clients_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "ClientListKernel iteration invalid. servertime is " << iter_num;
MS_LOG(ERROR) << "ClientListKernel iteration invalid. clienttime is " << iter_client;
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
uint64_t update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
}
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) { // client in client_list.
if (static_cast<uint64_t>(client_list_pb.fl_id_size()) >= update_model_client_num) {
MS_LOG(INFO) << "send clients_list succeed!";
MS_LOG(INFO) << "UpdateModel client list: ";
for (size_t i = 0; i < client_list.size(); ++i) {
MS_LOG(INFO) << " fl_id : " << client_list[i];
}
MS_LOG(INFO) << "update_model_client_num: " << update_model_client_num;
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
response = true;
} else {
MS_LOG(INFO) << "The server is not ready. update_model_client_need_num: " << update_model_client_num;
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
/*for (size_t i = 0; i < std::min(client_list.size(), size_t(2)); ++i) {
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
}
for (size_t i = client_list.size() - size_t(1); i > std::max(client_list.size() - size_t(2), size_t(0));
--i) {
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
}*/
int count_tmp = 0;
for (size_t i = 0; i < cipher_init_->get_model_num_need_; ++i) {
size_t j = 0;
for (; j < client_list.size(); ++j) {
if (("f" + std::to_string(i)) == client_list[j]) break;
}
if (j >= client_list.size()) {
count_tmp++;
MS_LOG(INFO) << " no client_list fl_id : " << i;
if (count_tmp > 3) break;
}
}
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
}
}
if (response) {
DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str());
}
} else {
MS_LOG(ERROR) << "update_model_client_num is zero.";
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_num is zero.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
}
if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
MS_LOG(ERROR) << "update_model_client_threshold is not set.";
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_threshold is not set.",
empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false;
}
return response;
uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
}
if (static_cast<uint64_t>(client_list.size()) < update_model_client_needed) {
MS_LOG(INFO) << "The server is not ready. update_model_client_needed: " << update_model_client_needed;
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false;
}
if (find(client_list.begin(), client_list.end(), fl_id) == client_list.end()) { // client not in update model clients
std::string reason = "fl_id: " + fl_id + " is not in the update_model_clients";
MS_LOG(INFO) << reason;
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false;
}
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetUpdateModelClientList, fl_id);
if (!retcode_client) {
std::string reason = "update get update model clients failed";
MS_LOG(ERROR) << reason;
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false;
}
if (!DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str())) {
std::string reason = "Counting for get user list request failed. Please retry later.";
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
return true;
}
MS_LOG(INFO) << "send clients_list succeed!";
MS_LOG(INFO) << "UpdateModel client list: ";
for (size_t i = 0; i < client_list.size(); ++i) {
MS_LOG(INFO) << " fl_id : " << client_list[i];
}
MS_LOG(INFO) << "update_model_client_needed: " << update_model_client_needed;
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return true;
}
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
clock_t start_time = clock();
std::vector<string> client_list;
if (inputs.size() != 1) {
MS_LOG(ERROR) << "ClientListKernel needs 1 input,but got " << inputs.size();
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel input num not match", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ClientListKernel needs 1 output,but got " << outputs.size();
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel output num not match", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "ClientListKernel num is enough", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
void *req_data = inputs[0]->addr;
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
if (get_clients_req == nullptr || fbb == nullptr) {
MS_LOG(ERROR) << "GetClientList is nullptr or ClientListRsp builder is nullptr.";
BuildClientListRsp(fbb, schema::ResponseCode_RequestError,
"GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
DealClient(iter_num, get_clients_req, fbb);
}
}
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
std::vector<string> client_list;
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
int32_t iter_client = (size_t)get_clients_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "client list iteration number is invalid: server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
}
DealClient(iter_num, get_clients_req, fbb);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
@ -164,30 +158,22 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli
const int iteration) {
auto rsp_reason = client_list_resp_builder->CreateString(reason);
auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time);
if (clients.size() > 0) {
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
for (auto client : clients) {
auto client_fb = client_list_resp_builder->CreateString(client);
clients_vector.push_back(client_fb);
}
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_clients(clients_fb);
rsp_builder.add_iteration(iteration);
rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys);
} else {
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_iteration(iteration);
rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys);
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
for (auto client : clients) {
auto client_fb = client_list_resp_builder->CreateString(client);
clients_vector.push_back(client_fb);
MS_LOG(WARNING) << "update client list: ";
MS_LOG(WARNING) << client;
}
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_clients(clients_fb);
rsp_builder.add_iteration(iteration);
rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys);
return;
}

View File

@ -38,9 +38,36 @@ void ExchangeKeysKernel::InitKernel(size_t) {
cipher_key_ = &armour::CipherKeys::GetInstance();
}
bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for exchangeKey is enough. Please retry later.";
cipher_key_->BuildExchangeKeysRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
MS_LOG(WARNING) << reason;
return true;
}
return false;
}
bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestExchangeKeys *exchange_keys_req,
const int iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(exchange_keys_req, false);
if (!DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str())) {
std::string reason = "Counting for exchange kernel request failed. Please retry later.";
cipher_key_->BuildExchangeKeysRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
MS_LOG(INFO) << "Launching ExchangeKey kernel.";
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
@ -48,46 +75,49 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 input,but got " << inputs.size();
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel input num not match",
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
void *req_data = inputs[0]->addr;
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
if (ReachThresholdForExchangeKeys(fbb, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
int32_t iter_client = (size_t)exchange_keys_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "ExchangeKeys iteration number is invalid: server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 output,but got " << outputs.size();
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel output num not match",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ExchangeKeysKernel is enough.";
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ExchangeKeysKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
void *req_data = inputs[0]->addr;
const schema::RequestExchangeKeys *exchange_keys_req =
flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
int32_t iter_client = (size_t)exchange_keys_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "ExchangeKeysKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
response =
cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
if (response) {
DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str());
}
}
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
if (!response) {
MS_LOG(WARNING) << "update exchange keys is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForExchangeKeys(fbb, exchange_keys_req, iter_num)) {
MS_LOG(ERROR) << "count for exchange keys failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
if (!response) {
MS_LOG(INFO) << "ExchangeKeysKernel response is false.";
}
return true;
}

View File

@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#include <vector>
#include <string>
#include <memory>
#include "fl/server/common.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
@ -41,6 +43,9 @@ class ExchangeKeysKernel : public RoundKernel {
Executor *executor_;
size_t iteration_time_window_;
armour::CipherKeys *cipher_key_;
bool ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num);
bool CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestExchangeKeys *exchange_keys_req,
const int iter_num);
};
} // namespace kernel
} // namespace server

View File

@ -37,9 +37,23 @@ void GetKeysKernel::InitKernel(size_t) {
cipher_key_ = &armour::CipherKeys::GetInstance();
}
bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
const int iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(get_keys_req, false);
if (!DistributedCountService::GetInstance().Count(name_, get_keys_req->fl_id()->str())) {
std::string reason = "Counting for getkeys kernel request failed. Please retry later.";
cipher_key_->BuildGetKeysRsp(
fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), false);
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
MS_LOG(INFO) << "Launching GetKeys kernel.";
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
@ -47,44 +61,48 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "GetKeysKernel needs 1 input,but got " << inputs.size();
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "GetKeysKernel needs 1 output,but got " << outputs.size();
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
} else {
void *req_data = inputs[0]->addr;
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
int32_t iter_client = (size_t)get_exchange_keys_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
} else {
response =
cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
if (response) {
DistributedCountService::GetInstance().Count(name_, get_exchange_keys_req->fl_id()->str());
}
}
}
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
}
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
int32_t iter_client = (size_t)get_exchange_keys_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
if (!response) {
MS_LOG(WARNING) << "get public keys is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForGetKeys(fbb, get_exchange_keys_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
if (!response) {
MS_LOG(INFO) << "GetKeysKernel response is false.";
}
return true;
}

View File

@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
#include <vector>
#include <string>
#include <memory>
#include "fl/server/common.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
@ -41,6 +43,8 @@ class GetKeysKernel : public RoundKernel {
Executor *executor_;
size_t iteration_time_window_;
armour::CipherKeys *cipher_key_;
bool CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
const int iter_num);
};
} // namespace kernel
} // namespace server

View File

@ -39,54 +39,72 @@ void GetSecretsKernel::InitKernel(size_t) {
cipher_share_ = &armour::CipherShares::GetInstance();
}
bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
const schema::GetShareSecrets *get_secrets_req, const int iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false);
if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) {
std::string reason = "Counting for get secrets kernel request failed. Please retry later.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), nullptr);
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
bool response = false;
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "GetSecretsKernel needs 1 input,but got " << inputs.size();
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "GetSecretsKernel needs 1 output,but got " << outputs.size();
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
} else {
void *req_data = inputs[0]->addr;
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
int32_t iter_client = (size_t)get_secrets_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
} else {
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
if (response) {
DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str());
}
}
}
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
int32_t iter_client = (size_t)get_secrets_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
}
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
if (!response) {
MS_LOG(WARNING) << "get secret shares is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForGetSecrets(fbb, get_secrets_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
if (!response) {
MS_LOG(INFO) << "GetSecretsKernel response is false.";
}
return true;
}

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#include <vector>
#include <memory>
#include "fl/server/common.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
@ -41,6 +42,8 @@ class GetSecretsKernel : public RoundKernel {
Executor *executor_;
size_t iteration_time_window_;
armour::CipherShares *cipher_share_;
bool CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::GetShareSecrets *get_secrets_req,
const int iter_num);
};
} // namespace kernel
} // namespace server

View File

@ -52,7 +52,6 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
// MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
@ -61,71 +60,58 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
"ReconstructSecretsKernel input num not match.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 output, but got " << outputs.size();
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
"ReconstructSecretsKernel output num not match.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
// get client list from memory server.
std::vector<string> update_model_clients;
const PBMetadata update_model_clients_pb_out =
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
for (int i = 0; i < update_model_clients_pb.fl_id_size(); ++i) {
update_model_clients.push_back(update_model_clients_pb.fl_id(i));
}
const schema::SendReconstructSecret *reconstruct_secret_req =
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
std::string fl_id = reconstruct_secret_req->fl_id()->str();
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
// client in get update model client list.
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED,
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
} else {
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
void *req_data = inputs[0]->addr;
const schema::SendReconstructSecret *reconstruct_secret_req =
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
// get client list from memory server.
std::vector<string> client_list;
uint64_t update_model_client_num = 0;
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
} else {
MS_LOG(ERROR) << "update_model_client_num is zero.";
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
"update_model_client_num is zero.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const PBMetadata client_list_pb_out =
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
int client_list_actual_size = client_list_pb.fl_id_size();
if (client_list_actual_size < 0) {
client_list_actual_size = 0;
}
if (static_cast<uint64_t>(client_list_actual_size) < update_model_client_num) {
MS_LOG(INFO) << "ReconstructSecretsKernel : client list is not ready " << inputs.size();
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SucNotReady,
"ReconstructSecretsKernel : client list is not ready", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
}
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
reconstruct_secret_req, fbb, client_list);
if (response) {
// MS_LOG(INFO) << "start ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
// MS_LOG(INFO) << "end ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
reconstruct_secret_req, fbb, update_model_clients);
if (response) {
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
}
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough.";
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
@ -142,7 +128,6 @@ void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
MS_LOG(INFO) << "start unmask";
while (!Executor::GetInstance().Unmask()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));

View File

@ -36,58 +36,75 @@ void ShareSecretsKernel::InitKernel(size_t) {
cipher_share_ = &armour::CipherShares::GetInstance();
}
bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestShareSecrets *share_secrets_req,
const int iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(share_secrets_req, false);
if (!DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str())) {
std::string reason = "Counting for share secret kernel request failed. Please retry later.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
bool response = false;
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 input,but got " << inputs.size();
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError, "ShareSecretsKernel input num not match",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 output,but got " << outputs.size();
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError,
"ShareSecretsKernel output num not match",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ShareSecretsKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
void *req_data = inputs[0]->addr;
const schema::RequestShareSecrets *share_secrets_req =
flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
size_t iter_client = (size_t)share_secrets_req->iteration();
if (iter_num != iter_client) {
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
response =
cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
if (response) {
DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str());
}
}
}
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ShareSecretsKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
size_t iter_client = (size_t)share_secrets_req->iteration();
if (iter_num != iter_client) {
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
response = cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
if (!response) {
MS_LOG(WARNING) << "update secret shares is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForShareSecrets(fbb, share_secrets_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
if (!response) {
MS_LOG(INFO) << "share_secrets_kernel response is false.";
}
return true;
}

View File

@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#include <vector>
#include <string>
#include <memory>
#include "fl/server/common.h"
#include "fl/server/executor.h"
#include "fl/server/kernel/round/round_kernel.h"
@ -41,6 +43,8 @@ class ShareSecretsKernel : public RoundKernel {
Executor *executor_;
size_t iteration_time_window_;
armour::CipherShares *cipher_share_;
bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req,
const int iter_num);
};
} // namespace kernel
} // namespace server

View File

@ -157,14 +157,31 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
std::string update_model_fl_id = update_model_req->fl_id()->str();
MS_LOG(INFO) << "Update model for fl id " << update_model_fl_id;
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id;
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
}
} else {
std::vector<std::string> get_secrets_clients;
#ifdef ENABLE_ARMOUR
mindspore::armour::CipherMetaStorage cipher_meta_storage;
cipher_meta_storage.GetClientListFromServer(fl::server::kCtxGetSecretsClientList, &get_secrets_clients);
#endif
if (find(get_secrets_clients.begin(), get_secrets_clients.end(), update_model_fl_id) ==
get_secrets_clients.end()) { // the client not in get_secrets_clients
std::string reason = "fl_id: " + update_model_fl_id + " is not in get_secrets_clients. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
}
}
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();

View File

@ -25,6 +25,9 @@
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
#include "fl/server/executor.h"
#ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_meta_storage.h"
#endif
namespace mindspore {
namespace fl {

View File

@ -213,53 +213,24 @@ void Server::InitIteration() {
#ifdef ENABLE_ARMOUR
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) {
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold;
cipher_exchange_keys_cnt_ = cipher_config_.exchange_keys_threshold;
cipher_get_keys_cnt_ = cipher_config_.get_keys_threshold;
cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
cipher_time_window_ = cipher_config_.cipher_time_window;
MS_LOG(INFO) << "Initializing cipher:";
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
MS_LOG(INFO) << " cipher_exchange_keys_cnt_: " << cipher_exchange_keys_cnt_
<< " cipher_get_keys_cnt_: " << cipher_get_keys_cnt_
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
<< " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
<< " cipher_time_window_: " << cipher_time_window_
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
std::shared_ptr<Round> exchange_keys_round =
std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
MS_EXCEPTION_IF_NULL(exchange_keys_round);
iteration_->AddRound(exchange_keys_round);
std::shared_ptr<Round> get_keys_round =
std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
MS_EXCEPTION_IF_NULL(get_keys_round);
iteration_->AddRound(get_keys_round);
std::shared_ptr<Round> share_secrets_round =
std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
MS_EXCEPTION_IF_NULL(share_secrets_round);
iteration_->AddRound(share_secrets_round);
std::shared_ptr<Round> get_secrets_round =
std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
MS_EXCEPTION_IF_NULL(get_secrets_round);
iteration_->AddRound(get_secrets_round);
std::shared_ptr<Round> get_clientlist_round =
std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
MS_EXCEPTION_IF_NULL(get_clientlist_round);
iteration_->AddRound(get_clientlist_round);
std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
"reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
MS_EXCEPTION_IF_NULL(reconstruct_secrets_round);
iteration_->AddRound(reconstruct_secrets_round);
MS_LOG(INFO) << "Cipher rounds has been added.";
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
<< " cipher_time_window_: " << cipher_time_window_;
}
#endif
@ -314,8 +285,8 @@ void Server::InitCipher() {
param.dp_eps = dp_eps;
param.dp_norm_clip = dp_norm_clip;
param.encrypt_type = encrypt_type;
cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_,
cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
cipher_reconstruct_secrets_up_cnt_);
#endif
}

View File

@ -80,7 +80,7 @@ class Server {
worker_num_(0),
fl_server_port_(0),
cipher_initial_client_cnt_(0),
cipher_exchange_secrets_cnt_(0),
cipher_exchange_keys_cnt_(0),
cipher_share_secrets_cnt_(0),
cipher_get_clientlist_cnt_(0),
cipher_reconstruct_secrets_up_cnt_(0),
@ -197,8 +197,10 @@ class Server {
uint32_t worker_num_;
uint16_t fl_server_port_;
size_t cipher_initial_client_cnt_;
size_t cipher_exchange_secrets_cnt_;
size_t cipher_exchange_keys_cnt_;
size_t cipher_get_keys_cnt_;
size_t cipher_share_secrets_cnt_;
size_t cipher_get_secrets_cnt_;
size_t cipher_get_clientlist_cnt_;
size_t cipher_reconstruct_secrets_up_cnt_;
size_t cipher_reconstruct_secrets_down_cnt_;

View File

@ -856,9 +856,33 @@ bool StartServerAction(const ResourcePtr &res) {
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold();
size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold() + 1;
fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshold};
size_t exchange_keys_threshold =
std::max(static_cast<size_t>(std::ceil(start_fl_job_threshold * share_secrets_ratio)), update_model_threshold);
size_t get_keys_threshold =
std::max(static_cast<size_t>(std::ceil(exchange_keys_threshold * share_secrets_ratio)), update_model_threshold);
size_t share_secrets_threshold =
std::max(static_cast<size_t>(std::ceil(get_keys_threshold * share_secrets_ratio)), update_model_threshold);
size_t get_secrets_threshold =
std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold);
size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)),
reconstruct_secrets_threshold);
#ifdef ENABLE_ARMOUR
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) {
MS_LOG(INFO) << "Add secure aggregation rounds.";
rounds_config.push_back({"exchangeKeys", true, cipher_time_window, true, exchange_keys_threshold});
rounds_config.push_back({"getKeys", true, cipher_time_window, true, get_keys_threshold});
rounds_config.push_back({"shareSecrets", true, cipher_time_window, true, share_secrets_threshold});
rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold});
rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
}
#endif
fl::server::CipherConfig cipher_config = {
share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,
share_secrets_threshold, get_secrets_threshold, client_list_threshold, reconstruct_secrets_threshold};
size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {

View File

@ -126,6 +126,13 @@ message SharesPb {
message KeysPb {
repeated bytes key = 1;
string timestamp = 2;
int32 iter_num = 3;
bytes ind_iv = 4;
bytes pw_iv = 5;
bytes pw_salt = 6;
bytes signature = 7;
repeated string certificate_chain = 8;
}
message Prime {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,12 +16,18 @@
package com.mindspore.flclient;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.I_VEC_LEN;
import static com.mindspore.flclient.LocalFLParameter.SALT_SIZE;
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.cipher.AESEncrypt;
import com.mindspore.flclient.cipher.BaseUtil;
import com.mindspore.flclient.cipher.ClientListReq;
import com.mindspore.flclient.cipher.KEYAgreement;
import com.mindspore.flclient.cipher.Random;
import com.mindspore.flclient.cipher.Masking;
import com.mindspore.flclient.cipher.ReconstructSecretReq;
import com.mindspore.flclient.cipher.ShareSecrets;
import com.mindspore.flclient.cipher.struct.ClientPublicKey;
@ -29,6 +35,7 @@ import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import com.mindspore.flclient.cipher.struct.EncryptShare;
import com.mindspore.flclient.cipher.struct.NewArray;
import com.mindspore.flclient.cipher.struct.ShareSecret;
import mindspore.schema.ClientShare;
import mindspore.schema.GetExchangeKeys;
import mindspore.schema.GetShareSecrets;
@ -40,22 +47,22 @@ import mindspore.schema.ResponseShareSecrets;
import mindspore.schema.ReturnExchangeKeys;
import mindspore.schema.ReturnShareSecrets;
import java.io.UnsupportedEncodingException;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.time.LocalDateTime;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE;
/**
* A class used for secure aggregation
*
* @since 2021-8-27
*/
public class CipherClient {
private static final Logger LOGGER = Logger.getLogger(CipherClient.class.toString());
private FLCommunication flCommunication;
@ -63,129 +70,217 @@ public class CipherClient {
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private final int iteration;
private int featureSize;
private int t;
private int minShareNum;
private List<byte[]> cKey = new ArrayList<>();
private List<byte[]> sKey = new ArrayList<>();
private byte[] bu;
private byte[] individualIv = new byte[I_VEC_LEN];
private byte[] pwIVec = new byte[I_VEC_LEN];
private byte[] pwSalt = new byte[SALT_SIZE];
private String nextRequestTime;
private Map<String, ClientPublicKey> clientPublicKeyList = new HashMap<String, ClientPublicKey>();
private Map<String, byte[]> sUVKeys = new HashMap<String, byte[]>();
private Map<String, byte[]> cUVKeys = new HashMap<String, byte[]>();
private List<EncryptShare> clientShareList = new ArrayList<>();
private List<EncryptShare> returnShareList = new ArrayList<>();
private float[] featureMask;
private List<String> u1ClientList = new ArrayList<>();
private List<String> u2UClientList = new ArrayList<>();
private List<String> u3ClientList = new ArrayList<>();
private List<DecryptShareSecrets> decryptShareSecretsList = new ArrayList<>();
private byte[] prime;
private KEYAgreement keyAgreement = new KEYAgreement();
private Random random = new Random();
private Masking masking = new Masking();
private ClientListReq clientListReq = new ClientListReq();
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
private int retCode;
/**
* Construct function of cipherClient
*
* @param iter iteration number
* @param minSecretNum minimum secret shares number used for reconstruct secret
* @param prime prime value
* @param featureSize featureSize of network
*/
public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) {
flCommunication = FLCommunication.getInstance();
this.iteration = iter;
this.featureSize = featureSize;
this.t = minSecretNum;
this.minShareNum = minSecretNum;
this.prime = prime;
this.featureMask = new float[this.featureSize];
}
/**
* Set next request time
*
* @param nextRequestTime next request timestamp
*/
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
public void setBU(byte[] bu) {
this.bu = bu;
}
public void setClientShareList(List<EncryptShare> clientShareList) {
/**
* Set client share list
*
* @param clientShareList client share list
*/
private void setClientShareList(List<EncryptShare> clientShareList) {
this.clientShareList.clear();
this.clientShareList = clientShareList;
}
/**
* get next request time
*
* @return next request time
*/
public String getNextRequestTime() {
return nextRequestTime;
}
/**
* get retCode
*
* @return retCode
*/
public int getRetCode() {
return retCode;
}
public void genDHKeyPairs() {
private FLClientStatus genDHKeyPairs() {
byte[] csk = keyAgreement.generatePrivateKey();
byte[] cpk = keyAgreement.generatePublicKey(csk);
if (cpk == null || cpk.length == 0) {
LOGGER.severe(Common.addTag("[genDHKeyPairs] the return byte[] <cpk> is null, please check!"));
return FLClientStatus.FAILED;
}
byte[] ssk = keyAgreement.generatePrivateKey();
byte[] spk = keyAgreement.generatePublicKey(ssk);
if (spk == null || spk.length == 0) {
LOGGER.severe(Common.addTag("[genDHKeyPairs] the return byte[] <spk> is null, please check!"));
return FLClientStatus.FAILED;
}
this.cKey.clear();
this.sKey.clear();
this.cKey.add(cpk);
this.cKey.add(csk);
this.sKey.add(spk);
this.sKey.add(ssk);
return FLClientStatus.SUCCESS;
}
public void genIndividualSecret() {
private FLClientStatus genIndividualSecret() {
byte[] key = new byte[SEED_SIZE];
random.getRandomBytes(key);
setBU(key);
int tag = masking.getRandomBytes(key);
if (tag == -1) {
LOGGER.severe(Common.addTag("[genIndividualSecret] the return value is -1, please check!"));
return FLClientStatus.FAILED;
}
this.bu = key;
return FLClientStatus.SUCCESS;
}
public List<ShareSecret> genSecretShares(byte[] secret) throws UnsupportedEncodingException {
List<ShareSecret> shareSecretList = new ArrayList<>();
private List<ShareSecret> genSecretShares(byte[] secret) {
if (secret == null || secret.length == 0) {
LOGGER.severe(Common.addTag("[genSecretShares] the input argument <secret> is null"));
return new ArrayList<>();
}
int size = u1ClientList.size();
ShareSecrets shamir = new ShareSecrets(t, size - 1);
ShareSecrets.SecretShare[] shares = shamir.split(secret, prime);
int j = 0;
for (int i = 0; i < size; i++) {
String vFlID = u1ClientList.get(i);
if (size <= 1) {
LOGGER.severe(Common.addTag("[genSecretShares] the size of u1ClientList is not valid: <= 1, it should be " +
"> 1"));
return new ArrayList<>();
}
ShareSecrets shamir = new ShareSecrets(minShareNum, size - 1);
ShareSecrets.SecretShares[] shares = shamir.split(secret, prime);
if (shares == null || shares.length == 0) {
LOGGER.severe(Common.addTag("[genSecretShares] the return ShareSecrets.SecretShare[] is null, please " +
"check!"));
return new ArrayList<>();
}
int shareIndex = 0;
List<ShareSecret> shareSecretList = new ArrayList<>();
for (String vFlID : u1ClientList) {
if (localFLParameter.getFlID().equals(vFlID)) {
continue;
} else {
ShareSecret shareSecret = new ShareSecret();
NewArray<byte[]> array = new NewArray<>();
int index = shares[j].getNum();
BigInteger intShare = shares[j].getShare();
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
array.setSize(share.length);
array.setArray(share);
shareSecret.setFlID(vFlID);
shareSecret.setShare(array);
shareSecret.setIndex(index);
shareSecretList.add(shareSecret);
j += 1;
}
if (shareIndex >= shares.length) {
LOGGER.severe(Common.addTag("[genSecretShares] the shareIndex is out of range in array <shares>, " +
"please check!"));
return new ArrayList<>();
}
int index = shares[shareIndex].getNumber();
BigInteger intShare = shares[shareIndex].getShares();
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
NewArray<byte[]> array = new NewArray<>();
array.setSize(share.length);
array.setArray(share);
ShareSecret shareSecret = new ShareSecret();
shareSecret.setFlID(vFlID);
shareSecret.setShare(array);
shareSecret.setIndex(index);
shareSecretList.add(shareSecret);
shareIndex += 1;
}
return shareSecretList;
}
public void genEncryptExchangedKeys() throws InvalidKeySpecException, NoSuchAlgorithmException {
private FLClientStatus genEncryptExchangedKeys() {
cUVKeys.clear();
for (String key : clientPublicKeyList.keySet()) {
ClientPublicKey curPublicKey = clientPublicKeyList.get(key);
String vFlID = curPublicKey.getFlID();
if (localFLParameter.getFlID().equals(vFlID)) {
continue;
} else {
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
byte[] salt = new byte[0];
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
cUVKeys.put(vFlID, secret);
}
if (cKey.size() < 2) {
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the size of cKey is not valid: < 2, it should " +
"be >= 2, please check!"));
return FLClientStatus.FAILED;
}
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
if (secret1 == null || secret1.length == 0) {
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the returned secret1 is null, please check!"));
return FLClientStatus.FAILED;
}
byte[] salt = new byte[0];
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
if (secret == null || secret.length == 0) {
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the returned secret is null, please check!"));
return FLClientStatus.FAILED;
}
cUVKeys.put(vFlID, secret);
}
return FLClientStatus.SUCCESS;
}
public void encryptShares() throws Exception {
LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for RequestShareSecrets **************"));
List<EncryptShare> encryptShareList = new ArrayList<>();
private FLClientStatus encryptShares() {
LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for " +
"RequestShareSecrets **************"));
// connect sSkUv, bUV, sIndex, indexB and then Encrypt them
if (sKey.size() < 2) {
LOGGER.severe(Common.addTag("[encryptShares] the size of sKey is not valid: < 2, it should be >= 2, " +
"please check!"));
return FLClientStatus.FAILED;
}
List<ShareSecret> sSkUv = genSecretShares(sKey.get(1));
if (sSkUv.isEmpty()) {
LOGGER.severe(Common.addTag("[encryptShares] the returned List<ShareSecret> sSkUv is empty, please " +
"check!"));
return FLClientStatus.FAILED;
}
List<ShareSecret> bUV = genSecretShares(bu);
if (sSkUv.isEmpty()) {
LOGGER.severe(Common.addTag("[encryptShares] the returned List<ShareSecret> bUV is empty, please check!"));
return FLClientStatus.FAILED;
}
if (sSkUv.size() != bUV.size()) {
LOGGER.severe(Common.addTag("[encryptShares] the sSkUv.size() should be equal to bUV.size(), please " +
"check!"));
return FLClientStatus.FAILED;
}
List<EncryptShare> encryptShareList = new ArrayList<>();
for (int i = 0; i < bUV.size(); i++) {
EncryptShare encryptShare = new EncryptShare();
NewArray<byte[]> array = new NewArray<>();
String vFlID = bUV.get(i).getFlID();
byte[] sShare = sSkUv.get(i).getShare().getArray();
byte[] bShare = bUV.get(i).getShare().getArray();
byte[] sIndex = BaseUtil.integer2byteArray(sSkUv.get(i).getIndex());
@ -200,53 +295,101 @@ public class CipherClient {
System.arraycopy(sShare, 0, allSecret, 4 + sIndex.length + bIndex.length, sShare.length);
System.arraycopy(bShare, 0, allSecret, 4 + sIndex.length + bIndex.length + sShare.length, bShare.length);
// encrypt:
byte[] iVecIn = new byte[IVEC_LEN];
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), iVecIn, "CBC");
String vFlID = bUV.get(i).getFlID();
if (!cUVKeys.containsKey(vFlID)) {
LOGGER.severe(Common.addTag("[encryptShares] the key " + vFlID + " is not in map cUVKeys, please " +
"check!"));
return FLClientStatus.FAILED;
}
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), "CBC");
byte[] encryptData = aesEncrypt.encrypt(cUVKeys.get(vFlID), allSecret);
if (encryptData == null || encryptData.length == 0) {
LOGGER.severe(Common.addTag("[encryptShares] the return byte[] is null, please check!"));
return FLClientStatus.FAILED;
}
NewArray<byte[]> array = new NewArray<>();
array.setSize(encryptData.length);
array.setArray(encryptData);
EncryptShare encryptShare = new EncryptShare();
encryptShare.setFlID(vFlID);
encryptShare.setShare(array);
encryptShareList.add(encryptShare);
}
setClientShareList(encryptShareList);
return FLClientStatus.SUCCESS;
}
public float[] doubleMaskingWeight() throws Exception {
int size = u2UClientList.size();
/**
* get masked weight of secure aggregation
*
* @return masked weight
*/
public float[] doubleMaskingWeight() {
List<Float> noiseBu = new ArrayList<>();
random.randomAESCTR(noiseBu, featureSize, bu);
int tag = masking.getMasking(noiseBu, featureSize, bu, individualIv);
if (tag == -1) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the return value is -1, please check!"));
return new float[0];
}
float[] mask = new float[featureSize];
for (int i = 0; i < size; i++) {
String vFlID = u2UClientList.get(i);
for (String vFlID : u2UClientList) {
if (!clientPublicKeyList.containsKey(vFlID)) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the key " + vFlID + " is not in map " +
"clientPublicKeyList, please check!"));
return new float[0];
}
ClientPublicKey curPublicKey = clientPublicKeyList.get(vFlID);
if (localFLParameter.getFlID().equals(vFlID)) {
continue;
}
byte[] salt;
byte[] iVec;
if (vFlID.compareTo(localFLParameter.getFlID()) < 0) {
salt = curPublicKey.getPwSalt().getArray();
iVec = curPublicKey.getPwIv().getArray();
} else {
byte[] salt = new byte[0];
byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray());
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
sUVKeys.put(vFlID, secret);
List<Float> noiseSuv = new ArrayList<>();
random.randomAESCTR(noiseSuv, featureSize, secret);
int sign;
if (localFLParameter.getFlID().compareTo(vFlID) > 0) {
sign = 1;
} else {
sign = -1;
}
for (int j = 0; j < noiseSuv.size(); j++) {
mask[j] = mask[j] + sign * noiseSuv.get(j);
}
salt = this.pwSalt;
iVec = this.pwIVec;
}
if (sKey.size() < 2) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the size of sKey is not valid: < 2, it should be " +
">= 2, please check!"));
return new float[0];
}
byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray());
if (secret1 == null || secret1.length == 0) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the returned secret1 is null, please check!"));
return new float[0];
}
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
if (secret == null || secret.length == 0) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the returned secret is null, please check!"));
return new float[0];
}
sUVKeys.put(vFlID, secret);
List<Float> noiseSuv = new ArrayList<>();
tag = masking.getMasking(noiseSuv, featureSize, secret, iVec);
if (tag == -1) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the return value is -1, please check!"));
return new float[0];
}
int sign;
if (localFLParameter.getFlID().compareTo(vFlID) > 0) {
sign = 1;
} else {
sign = -1;
}
for (int maskIndex = 0; maskIndex < noiseSuv.size(); maskIndex++) {
mask[maskIndex] = mask[maskIndex] + sign * noiseSuv.get(maskIndex);
}
}
for (int j = 0; j < noiseBu.size(); j++) {
mask[j] = mask[j] + noiseBu.get(j);
for (int maskIndex = 0; maskIndex < noiseBu.size(); maskIndex++) {
mask[maskIndex] = mask[maskIndex] + noiseBu.get(maskIndex);
}
return mask;
}
public NewArray<byte[]> byteToArray(ByteBuffer buf, int size) {
private NewArray<byte[]> byteToArray(ByteBuffer buf, int size) {
NewArray<byte[]> newArray = new NewArray<>();
newArray.setSize(size);
byte[] array = new byte[size];
@ -258,40 +401,80 @@ public class CipherClient {
return newArray;
}
public FLClientStatus requestExchangeKeys() {
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "=============="));
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
genDHKeyPairs();
private FLClientStatus requestExchangeKeys() {
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() +
"=============="));
FLClientStatus status = genDHKeyPairs();
if (status == FLClientStatus.FAILED) {
LOGGER.severe(Common.addTag("[requestExchangeKeys] the return status is FAILED, please check!"));
return FLClientStatus.FAILED;
}
if (cKey.size() <= 0 || sKey.size() <= 0) {
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of cKey or sKey is not valid: <=0."));
return FLClientStatus.FAILED;
}
if (cKey.size() < 2) {
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of cKey is not valid: < 2, it should be >= 2," +
" please check!"));
return FLClientStatus.FAILED;
}
if (sKey.size() < 2) {
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of sKey is not valid: < 2, it should be >= 2," +
" please check!"));
return FLClientStatus.FAILED;
}
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
byte[] cPK = cKey.get(0);
byte[] sPK = sKey.get(0);
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
String dateTime = LocalDateTime.now().toString();
byte[] indIv = new byte[I_VEC_LEN];
byte[] pwIv = new byte[I_VEC_LEN];
byte[] thisPwSalt = new byte[SALT_SIZE];
SecureRandom secureRandom = Common.getSecureRandom();
secureRandom.nextBytes(indIv);
secureRandom.nextBytes(pwIv);
secureRandom.nextBytes(thisPwSalt);
this.individualIv = indIv;
this.pwIVec = pwIv;
this.pwSalt = thisPwSalt;
int indIvFbs = RequestExchangeKeys.createIndIvVector(fbBuilder, indIv);
int pwIvFbs = RequestExchangeKeys.createPwIvVector(fbBuilder, pwIv);
int pwSaltFbs = RequestExchangeKeys.createPwSaltVector(fbBuilder, thisPwSalt);
int id = fbBuilder.createString(localFLParameter.getFlID());
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime);
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time);
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time
, indIvFbs, pwIvFbs, pwSaltFbs);
fbBuilder.finish(exchangeKeysRoot);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try {
byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[requestExchangeKeys] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[requestExchangeKeys] the server is not ready now, need wait some time and" +
" " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer);
FLClientStatus status = judgeRequestExchangeKeys(responseExchangeKeys);
return status;
} catch (Exception e) {
e.printStackTrace();
return judgeRequestExchangeKeys(responseExchangeKeys);
} catch (IOException ex) {
LOGGER.severe(Common.addTag("[requestExchangeKeys] catch IOException: " + ex.getMessage()));
return FLClientStatus.FAILED;
}
}
public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
private FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -303,7 +486,8 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success"));
return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request startFLJob again"));
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
@ -311,39 +495,43 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys is invalid: " + retCode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys " +
"is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
public FLClientStatus getExchangeKeys() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
private FLClientStatus getExchangeKeys() {
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime);
int getExchangeKeysRoot = GetExchangeKeys.createGetExchangeKeys(fbBuilder, id, iteration, time);
fbBuilder.finish(getExchangeKeysRoot);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try {
byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getExchangeKeys] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getExchangeKeys] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer);
FLClientStatus status = judgeGetExchangeKeys(returnExchangeKeys);
return status;
} catch (Exception e) {
e.printStackTrace();
return judgeGetExchangeKeys(returnExchangeKeys);
} catch (IOException ex) {
LOGGER.severe(Common.addTag("[getExchangeKeys] catch IOException: " + ex.getMessage()));
return FLClientStatus.FAILED;
}
}
public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
private FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -363,17 +551,25 @@ public class CipherClient {
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
ByteBuffer bufPwIv = bufData.remotePublickeys(i).pwIvAsByteBuffer();
int sizePwIv = bufData.remotePublickeys(i).pwIvLength();
ByteBuffer bufPwSalt = bufData.remotePublickeys(i).pwSaltAsByteBuffer();
int sizePwSalt = bufData.remotePublickeys(i).pwSaltLength();
publicKey.setCPK(byteToArray(bufCpk, sizeCpk));
publicKey.setSPK(byteToArray(bufSpk, sizeSpk));
publicKey.setPwIv(byteToArray(bufPwIv, sizePwIv));
publicKey.setPwSalt(byteToArray(bufPwSalt, sizePwSalt));
clientPublicKeyList.put(srcFlId, publicKey);
u1ClientList.add(srcFlId);
}
return FLClientStatus.SUCCESS;
case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetExchangeKeys again!"));
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
"GetExchangeKeys again!"));
return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request startFLJob again"));
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
@ -381,32 +577,48 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is invalid: " + retCode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is" +
" invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
public FLClientStatus requestShareSecrets() throws Exception {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
genIndividualSecret();
genEncryptExchangedKeys();
encryptShares();
private FLClientStatus requestShareSecrets() {
FLClientStatus status = genIndividualSecret();
if (status == FLClientStatus.FAILED) {
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from genIndividualSecret" +
"(), please check!"));
return FLClientStatus.FAILED;
}
status = genEncryptExchangedKeys();
if (status == FLClientStatus.FAILED) {
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from " +
"genEncryptExchangedKeys(), please check!"));
return FLClientStatus.FAILED;
}
status = encryptShares();
if (status == FLClientStatus.FAILED) {
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from encryptShares(), " +
"please check!"));
return FLClientStatus.FAILED;
}
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime);
int clientShareSize = clientShareList.size();
if (clientShareSize <= 0) {
LOGGER.warning(Common.addTag("[PairWiseMask] encrypt shares is not ready now!"));
Common.sleep(SLEEP_TIME);
FLClientStatus status = requestShareSecrets();
return status;
return requestShareSecrets();
} else {
int[] add = new int[clientShareSize];
for (int i = 0; i < clientShareSize; i++) {
int flID = fbBuilder.createString(clientShareList.get(i).getFlID());
int shareSecretFbs = ClientShare.createShareVector(fbBuilder, clientShareList.get(i).getShare().getArray());
int shareSecretFbs = ClientShare.createShareVector(fbBuilder,
clientShareList.get(i).getShare().getArray());
ClientShare.startClientShare(fbBuilder);
ClientShare.addFlId(fbBuilder, flID);
ClientShare.addShare(fbBuilder, shareSecretFbs);
@ -414,29 +626,33 @@ public class CipherClient {
add[i] = clientShareRoot;
}
int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add);
int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id, encryptedSharesFbs, iteration, time);
int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id,
encryptedSharesFbs, iteration, time);
fbBuilder.finish(requestShareSecretsRoot);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try {
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[requestShareSecrets] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[requestShareSecrets] the server is not ready now, need wait some time" +
" " +
"and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer);
FLClientStatus status = judgeRequestShareSecrets(responseShareSecrets);
return status;
} catch (Exception e) {
e.printStackTrace();
return judgeRequestShareSecrets(responseShareSecrets);
} catch (IOException ex) {
LOGGER.severe(Common.addTag("[requestShareSecrets] catch IOException: " + ex.getMessage()));
return FLClientStatus.FAILED;
}
}
}
public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
private FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -448,7 +664,8 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success"));
return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request startFLJob again"));
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
@ -456,39 +673,43 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets is invalid: " + retCode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets " +
"is invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
public FLClientStatus getShareSecrets() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
private FLClientStatus getShareSecrets() {
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime);
int getShareSecrets = GetShareSecrets.createGetShareSecrets(fbBuilder, id, iteration, time);
fbBuilder.finish(getShareSecrets);
byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try {
byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getShareSecrets] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getShareSecrets] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer);
FLClientStatus status = judgeGetShareSecrets(returnShareSecrets);
return status;
} catch (Exception e) {
e.printStackTrace();
return judgeGetShareSecrets(returnShareSecrets);
} catch (IOException ex) {
LOGGER.severe(Common.addTag("[getShareSecrets] catch IOException: " + ex.getMessage()));
return FLClientStatus.FAILED;
}
}
public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
private FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -503,20 +724,26 @@ public class CipherClient {
int length = bufData.encryptedSharesLength();
for (int i = 0; i < length; i++) {
EncryptShare shareSecret = new EncryptShare();
shareSecret.setFlID(bufData.encryptedShares(i).flId());
ByteBuffer bufShare = bufData.encryptedShares(i).shareAsByteBuffer();
int sizeShare = bufData.encryptedShares(i).shareLength();
ClientShare clientShare = bufData.encryptedShares(i);
if (clientShare == null) {
LOGGER.severe(Common.addTag("[PairWiseMask] the clientShare returned from server is null"));
return FLClientStatus.FAILED;
}
shareSecret.setFlID(clientShare.flId());
ByteBuffer bufShare = clientShare.shareAsByteBuffer();
int sizeShare = clientShare.shareLength();
shareSecret.setShare(byteToArray(bufShare, sizeShare));
returnShareList.add(shareSecret);
u2UClientList.add(bufData.encryptedShares(i).flId());
u2UClientList.add(clientShare.flId());
}
return FLClientStatus.SUCCESS;
case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetShareSecrets again!"));
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
"GetShareSecrets again!"));
return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request startFLJob again"));
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
@ -524,15 +751,22 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is invalid: " + retCode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is" +
" invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
/**
* exchangeKeys round of secure aggregation
*
* @return round execution result
*/
public FLClientStatus exchangeKeys() {
LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys ======================"));
FLClientStatus curStatus;
LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys " +
"======================"));
// RequestExchangeKeys
FLClientStatus curStatus;
curStatus = requestExchangeKeys();
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
@ -551,8 +785,14 @@ public class CipherClient {
return curStatus;
}
public FLClientStatus shareSecrets() throws Exception {
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets ======================")));
/**
* shareSecrets round of secure aggregation
*
* @return round execution result
*/
public FLClientStatus shareSecrets() {
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets " +
"======================")));
FLClientStatus curStatus;
// RequestShareSecrets
curStatus = requestShareSecrets();
@ -573,14 +813,22 @@ public class CipherClient {
return curStatus;
}
/**
* reconstructSecrets round of secure aggregation
*
* @return round execution result
*/
public FLClientStatus reconstructSecrets() {
LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret ========================"));
LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret " +
"========================"));
FLClientStatus curStatus;
// GetClientList
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList,
cUVKeys);
while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME);
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList
, cUVKeys);
}
if (curStatus == FLClientStatus.RESTART) {
nextRequestTime = clientListReq.getNextRequestTime();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,10 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import org.bouncycastle.crypto.BlockCipher;
import org.bouncycastle.crypto.engines.AESEngine;
import org.bouncycastle.crypto.prng.SP800SecureRandomBuilder;
import java.io.File;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
@ -26,28 +33,94 @@ import java.util.logging.Logger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Define basic global methods used in federated learning task.
*
* @since 2021-06-30
*/
public class Common {
/**
* Global logger title.
*/
public static final String LOG_TITLE = "<FLClient> ";
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "albert"));
public static String generateUrl(boolean useHttps, boolean useElb, String ip, int port, int serverNum) {
if (useHttps) {
ip = "https://" + ip + ":";
} else {
ip = "http://" + ip + ":";
/**
* The list of trust flName.
*/
public static final List<String> FL_NAME_TRUST_LIST = new ArrayList<>(Arrays.asList("lenet", "albert"));
/**
* The list of trust ssl protocol.
*/
public static final List<String> SSL_PROTOCOL_TRUST_LIST = new ArrayList<>(Arrays.asList("TLSv1.3", "TLSv1.2"));
/**
* The tag when server is in safe mode.
*/
public static final String SAFE_MOD = "The cluster is in safemode.";
/**
* The tag when server is not ready.
*/
public static final String JOB_NOT_AVAILABLE = "The server's training job is disabled or finished.";
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
private static SecureRandom secureRandom;
/**
* Generate the URL for device-sever interaction
*
* @param ifUseElb whether a client randomly sends a request to a server address within a specified range.
* @param serverNum number of servers that can send requests.
* @param domainName the URL for device-sever interaction set by user.
* @return the URL for device-sever interaction.
*/
public static String generateUrl(boolean ifUseElb, int serverNum, String domainName) {
if (serverNum <= 0) {
LOGGER.severe(Common.addTag("[generateUrl] the input argument <serverNum> is not valid: <= 0, it should " +
"be > 0, please check!"));
throw new IllegalArgumentException();
}
String url;
if (useElb) {
if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
LOGGER.severe(Common.addTag("[generateUrl] the input argument <domainName> is null or not valid, it " +
"should be like as https://...... or http://...... , please check!"));
throw new IllegalArgumentException();
}
if (ifUseElb) {
if (domainName.split("//")[1].split(":").length < 2) {
LOGGER.severe(Common.addTag("[generateUrl] the format of <domainName> is not valid, it should be like" +
" as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, please check!"));
throw new IllegalArgumentException();
}
String ip = domainName.split("//")[1].split(":")[0];
int port = Integer.parseInt(domainName.split("//")[1].split(":")[1]);
if (!Common.checkIP(ip)) {
LOGGER.severe(Common.addTag("[generateUrl] the <ip> split from domainName is not valid, domainName " +
"should be like as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, " +
"please check!"));
throw new IllegalArgumentException();
}
if (!Common.checkPort(port)) {
LOGGER.severe(Common.addTag("[generateUrl] the <port> split from domainName is not valid, domainName " +
"should be like as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, " +
"please check!"));
throw new IllegalArgumentException();
}
String tag = domainName.split("//")[0] + "//";
Random rand = new Random();
int randomNum = rand.nextInt(100000) % serverNum + port;
url = ip + String.valueOf(randomNum);
url = tag + ip + ":" + String.valueOf(randomNum);
} else {
url = ip + String.valueOf(port);
url = domainName;
}
return url;
}
/**
* Store weight name of classifier to a list.
*
* @param classifierWeightName the list to store weight name of classifier.
*/
public static void setClassifierWeightName(List<String> classifierWeightName) {
classifierWeightName.add("albert.pooler.weight");
classifierWeightName.add("albert.pooler.bias");
@ -56,6 +129,11 @@ public class Common {
LOGGER.info(addTag("classifierWeightName size: " + classifierWeightName.size()));
}
/**
* Store weight name of albert network to a list.
*
* @param albertWeightName the list to store weight name of albert network.
*/
public static void setAlbertWeightName(List<String> albertWeightName) {
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.weight");
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.bias");
@ -78,32 +156,67 @@ public class Common {
LOGGER.info(addTag("albertWeightName size: " + albertWeightName.size()));
}
/**
* Check whether the flName set by user is in the trust list.
*
* @param flName the model name set by user.
* @return boolean value, true indicates the flName set by user is valid, false indicates the flName set by user
* is not valid.
*/
public static boolean checkFLName(String flName) {
return (flNameTrustList.contains(flName));
return (FL_NAME_TRUST_LIST.contains(flName));
}
/**
* Check whether the sslProtocol set by user is in the trust list.
*
* @param sslProtocol the ssl protocol set by user.
* @return boolean value, true indicates the sslProtocol set by user is valid, false indicates the sslProtocol
* set by user is not valid.
*/
public static boolean checkSSLProtocol(String sslProtocol) {
return (SSL_PROTOCOL_TRUST_LIST.contains(sslProtocol));
}
/**
* The program waits for the specified time and then to continue.
*
* @param millis the waiting time (ms).
*/
public static void sleep(long millis) {
try {
Thread.sleep(millis); //1000 milliseconds is one second.
Thread.sleep(millis); // 1000 milliseconds is one second.
} catch (InterruptedException ex) {
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
Thread.currentThread().interrupt();
}
}
/**
* Get the waiting time for repeated requests.
*
* @param nextRequestTime the timestamp return from server.
* @return the waiting time for repeated requests.
*/
public static long getWaitTime(String nextRequestTime) {
Date date = new Date();
long currentTime = date.getTime();
long waitTime = 0;
if (!("").equals(nextRequestTime)) {
long waitTime = 0L;
if (!(nextRequestTime == null || nextRequestTime.isEmpty())) {
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
}
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime));
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " +
currentTime));
LOGGER.info(addTag("[getWaitTime] waitTime: " + waitTime));
return waitTime;
}
/**
* Get start time.
*
* @param tag the tag added to the logger.
* @return start time.
*/
public static long startTime(String tag) {
Date startDate = new Date();
long startTime = startDate.getTime();
@ -111,6 +224,12 @@ public class Common {
return startTime;
}
/**
* Get end time.
*
* @param start the start time.
* @param tag the tag added to the logger.
*/
public static void endTime(long start, String tag) {
Date endDate = new Date();
long endTime = endDate.getTime();
@ -118,53 +237,182 @@ public class Common {
LOGGER.info(addTag("[interval time] <" + tag + "> interval time(ms): " + (endTime - start)));
}
/**
* Add specified tag to the message.
*
* @param message the message need to add tag.
* @return the message after adding tag.
*/
public static String addTag(String message) {
return LOG_TITLE + message;
}
public static boolean isSafeMod(byte[] message, String safeModTag) {
return (new String(message)).contains(safeModTag);
/**
* Check whether the server is ready based on the message returned by the server.
*
* @param message the message returned by the server..
* @return boolean value, true indicates the server is ready, false indicates the server is not ready.
*/
public static boolean isSeverReady(byte[] message) {
if (message == null) {
LOGGER.severe(Common.addTag("[isSeverReady] the input argument <message> is null, please check!"));
throw new IllegalArgumentException();
}
String messageStr = new String(message);
if (messageStr.contains(SAFE_MOD)) {
LOGGER.info(Common.addTag("[isSeverReady] " + SAFE_MOD + ", need wait some time and request again"));
return false;
} else if (messageStr.contains(JOB_NOT_AVAILABLE)) {
LOGGER.info(Common.addTag("[isSeverReady] " + JOB_NOT_AVAILABLE + ", need wait some time and request " +
"again"));
return false;
} else {
return true;
}
}
public static String getRealPath (String path) {
LOGGER.info(addTag("[original path] " + path));
/**
* Convert a user-set path to a standard path.
*
* @param path the user-set path.
* @return the standard path.
*/
public static String getRealPath(String path) {
if (path == null) {
LOGGER.severe(Common.addTag("[getRealPath] the input argument <path> is null, please check!"));
throw new IllegalArgumentException();
}
LOGGER.info(addTag("[getRealPath] original path: " + path));
String[] paths = path.split(",");
for (int i = 0; i < paths.length; i++) {
LOGGER.info(addTag("[original path " + i + "] " + paths[i]));
if (paths[i] == null) {
LOGGER.severe(Common.addTag("[getRealPath] the paths[i] is null, please check"));
throw new IllegalArgumentException();
}
LOGGER.info(addTag("[getRealPath] original path " + i + ": " + paths[i]));
File file = new File(paths[i]);
try {
paths[i] = file.getCanonicalPath();
} catch (IOException e) {
LOGGER.severe(addTag("[checkPath] catch IOException in file.getCanonicalPath(): " + e.getMessage()));
throw new RuntimeException();
LOGGER.severe(addTag("[getRealPath] catch IOException in file.getCanonicalPath(): " + e.getMessage()));
throw new IllegalArgumentException();
}
}
path = String.join(",", Arrays.asList(paths));
LOGGER.info(addTag("[real path] " + path));
return path;
String realPath = String.join(",", Arrays.asList(paths));
LOGGER.info(addTag("[getRealPath] real path: " + realPath));
return realPath;
}
/**
* Check whether the path set by user exists.
*
* @param path the path set by user.
* @return boolean value, true indicates the path is exist, false indicates the path is not exist
*/
public static boolean checkPath(String path) {
boolean tag = true;
if (path == null) {
LOGGER.severe(Common.addTag("[checkPath] the input argument <path> is null, please check!"));
return false;
}
String[] paths = path.split(",");
for (int i = 0; i < paths.length; i++) {
if (paths[i] == null) {
LOGGER.severe(Common.addTag("[checkPath] the paths[i] is null, please check"));
return false;
}
LOGGER.info(addTag("[check path " + i + "] " + paths[i]));
File file = new File(paths[i]);
if (!file.exists()) {
tag = false;
LOGGER.severe(Common.addTag("[checkPath] the path is not exist, please check"));
return false;
}
}
return tag;
return true;
}
/**
* Check whether the ip set by user is valid.
*
* @param ip the ip set by user.
* @return boolean value, true indicates the ip is valid, false indicates the ip is not valid.
*/
public static boolean checkIP(String ip) {
String regex = "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])";
if (ip == null) {
LOGGER.severe(Common.addTag("[checkIP] the input argument <ip> is null, please check!"));
throw new IllegalArgumentException();
}
String regex = "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])[.]" +
"(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.]" +
"(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.]" +
"(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])";
Pattern pattern = Pattern.compile(regex);
Matcher matcher = pattern.matcher(ip);
return matcher.matches();
}
/**
* Check whether the port set by user is valid.
*
* @param port the port set by user.
* @return boolean value, true indicates the port is valid, false indicates the port is not valid.
*/
public static boolean checkPort(int port) {
return port > 0 && port <= 65535;
}
/**
* Obtain secure random.
*
* @return the secure random.
*/
public static SecureRandom getSecureRandom() {
if (secureRandom == null) {
LOGGER.severe(Common.addTag("[setSecureRandom] the parameter secureRandom is null, please set it before " +
"use"));
throw new IllegalArgumentException();
}
return secureRandom;
}
/**
* Set the secure random to parameter secureRandom of the class Common.
*
* @param secureRandom the secure random.
*/
public static void setSecureRandom(SecureRandom secureRandom) {
if (secureRandom == null) {
LOGGER.severe(Common.addTag("[setSecureRandom] the input parameter secureRandom is null, please check"));
throw new IllegalArgumentException();
}
Common.secureRandom = secureRandom;
}
/**
* Obtain fast secure random.
*
* @return the fast secure random.
*/
public static SecureRandom getFastSecureRandom() {
try {
LOGGER.info(Common.addTag("[getFastSecureRandom] start create fastSecureRandom"));
long start = System.currentTimeMillis();
SecureRandom blockingRandom = SecureRandom.getInstanceStrong();
boolean ifPredictionResistant = true;
BlockCipher cipher = new AESEngine();
int cipherLen = 256;
int entropyBitsRequired = 384;
byte[] nonce = null;
boolean ifForceReseed = false;
SecureRandom fastRandom = new SP800SecureRandomBuilder(blockingRandom, ifPredictionResistant)
.setEntropyBitsRequired(entropyBitsRequired)
.buildCTR(cipher, cipherLen, nonce, ifForceReseed);
fastRandom.nextInt();
LOGGER.info(Common.addTag("[getFastSecureRandom] finish create fastSecureRandom"));
LOGGER.info(Common.addTag("[getFastSecureRandom] cost time: " + (System.currentTimeMillis() - start)));
return fastRandom;
} catch (NoSuchAlgorithmException e) {
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException: " + e.getMessage()));
throw new IllegalArgumentException();
}
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,13 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
/**
* The early stop mod.
*
* @since 2021-06-30
*/
public enum EarlyStopMod {
LOSS_DIFF,
LOSS_ABS,
WEIGHT_DIFF,
NOT_EARLY_STOP
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,12 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
/**
* Security encryption level.
*
* @since 2021-06-30
*/
public enum EncryptLevel {
PW_ENCRYPT,
DP_ENCRYPT,
NOT_ENCRYPT
}
}

View File

@ -1,21 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
/**
* The status code of federated learning.
*
* @since 2021-06-30
*/
public enum FLClientStatus {
SUCCESS,
FAILED,

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
package com.mindspore.flclient;
import static com.mindspore.flclient.FLParameter.TIME_OUT;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.MediaType;
@ -24,12 +26,6 @@ import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
@ -39,34 +35,40 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.TIME_OUT;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
/**
* Define the communication interface.
*
* @since 2021-06-30
*/
public class FLCommunication implements IFLCommunication {
private static int timeOut;
private static boolean ssl = false;
private static String env;
private static SSLSocketFactory sslSocketFactory;
private static X509TrustManager x509TrustManager;
private FLParameter flParameter = FLParameter.getInstance();
private static boolean ifCertificateVerify = false;
private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8");
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
private OkHttpClient client;
private static volatile FLCommunication communication;
private FLParameter flParameter = FLParameter.getInstance();
private OkHttpClient client;
private FLCommunication() {
if (flParameter.getTimeOut() != 0) {
timeOut = flParameter.getTimeOut();
} else {
timeOut = TIME_OUT;
}
ssl = flParameter.isUseSSL();
ifCertificateVerify = flParameter.isUseSSL();
client = getOkHttpClient();
}
private static OkHttpClient getOkHttpClient() {
X509TrustManager trustManager = new X509TrustManager() {
@Override
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[]{};
@ -89,14 +91,15 @@ public class FLCommunication implements IFLCommunication {
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
if (ssl) {
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(), SSLSocketFactoryTools.getInstance().getmTrustManager());
if (ifCertificateVerify) {
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(),
SSLSocketFactoryTools.getInstance().getmTrustManager());
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
} else {
final SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
final javax.net.ssl.SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
builder.sslSocketFactory(sslSocketFactory, trustManager);
sslContext.init(null, trustAllCerts, Common.getSecureRandom());
final SSLSocketFactory sslFactory = sslContext.getSocketFactory();
builder.sslSocketFactory(sslFactory, trustManager);
builder.hostnameVerifier(new HostnameVerifier() {
@Override
public boolean verify(String arg0, SSLSession arg1) {
@ -104,14 +107,18 @@ public class FLCommunication implements IFLCommunication {
}
});
}
return builder.build();
} catch (NoSuchAlgorithmException | KeyManagementException e) {
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + e.getMessage()));
throw new RuntimeException(e);
} catch (NoSuchAlgorithmException | KeyManagementException ex) {
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + ex.getMessage()));
throw new IllegalArgumentException(ex);
}
}
/**
* Get the singleton object of the class FLCommunication.
*
* @return the singleton object of the class FLCommunication.
*/
public static FLCommunication getInstance() {
FLCommunication localRef = communication;
if (localRef == null) {
@ -138,6 +145,9 @@ public class FLCommunication implements IFLCommunication {
if (!response.isSuccessful()) {
throw new IOException("Unexpected code " + response);
}
if (response.body() == null) {
throw new IOException("the returned response is null");
}
return response.body().bytes();
}
@ -159,11 +169,10 @@ public class FLCommunication implements IFLCommunication {
}
@Override
public void onFailure(Call call, IOException e) {
asyncCallBack.onFailure(e);
public void onFailure(Call call, IOException ioException) {
asyncCallBack.onFailure(ioException);
call.cancel();
}
});
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,19 +13,40 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import java.util.logging.Logger;
/**
* Define job result callback function.
*
* @since 2021-06-30
*/
public class FLJobResultCallback implements IFLJobResultCallback {
private static final Logger LOGGER = Logger.getLogger(FLJobResultCallback.class.toString());
/**
* Called at the end of an iteration for Fl job
*
* @param modelName the name of model
* @param iterationSeq Iteration number
* @param resultCode Status Code
*/
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) {
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode));
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " +
iterationSeq + " resultCode: " + resultCode));
}
/**
* Called on completion for Fl job
*
* @param modelName the name of model
* @param iterationCount total Iteration numbers
* @param resultCode Status Code
*/
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode));
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " +
iterationCount + " resultCode: " + resultCode));
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,11 +16,15 @@
package com.mindspore.flclient;
import com.mindspore.flclient.cipher.BaseUtil;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.CipherPublicParams;
import mindspore.schema.FLPlan;
import mindspore.schema.ResponseCode;
@ -36,18 +40,21 @@ import java.util.Map;
import java.util.TreeMap;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
/**
* Defining the general process of federated learning tasks.
*
* @since 2021-06-30
*/
public class FLLiteClient {
private static final Logger LOGGER = Logger.getLogger(FLLiteClient.class.toString());
private FLCommunication flCommunication;
private static int iteration = 0;
private static Map<String, float[]> mapBeforeTrain;
private double dpNormClipFactor = 1.0d;
private double dpNormClipAdapt = 0.05d;
private FLCommunication flCommunication;
private FLClientStatus status;
private int retCode;
private static int iteration = 0;
private int iterations = 1;
private int epochs = 1;
private int batchSize = 16;
@ -55,22 +62,21 @@ public class FLLiteClient {
private byte[] prime;
private int featureSize;
private int trainDataSize;
private double dpEps = 100;
private double dpDelta = 0.01;
public double dpNormClipFactor = 1.0;
public double dpNormClipAdapt = 0.05;
private double dpEps = 100d;
private double dpDelta = 0.01d;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private SecureProtocol secureProtocol = new SecureProtocol();
private static Map<String, float[]> mapBeforeTrain;
private String nextRequestTime;
/**
* Defining a constructor of teh class FLLiteClient.
*/
public FLLiteClient() {
flCommunication = FLCommunication.getInstance();
}
public int setGlobalParameters(ResponseFLJob flJob) {
private int setGlobalParameters(ResponseFLJob flJob) {
FLPlan flPlan = flJob.flPlanConfig();
if (flPlan == null) {
LOGGER.severe(Common.addTag("[startFLJob] the FLPlan get from server is null"));
@ -90,14 +96,22 @@ public class FLLiteClient {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
TrainLenet trainLenet = TrainLenet.getInstance();
trainLenet.setBatchSize(batchSize);
} else {
LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid"));
return -1;
}
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize));
CipherPublicParams cipherPublicParams = flPlan.cipher();
if (cipherPublicParams == null) {
LOGGER.severe(Common.addTag("[startFLJob] the cipherPublicParams returned from server is null"));
return -1;
}
String encryptLevel = cipherPublicParams.encryptType();
if ("".equals(encryptLevel) || encryptLevel.isEmpty()) {
LOGGER.severe(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server is null, set the encryptLevel to NOT_ENCRYPT "));
if (encryptLevel == null || encryptLevel.isEmpty()) {
LOGGER.severe(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server is null, set the " +
"encryptLevel to NOT_ENCRYPT "));
localFLParameter.setEncryptLevel(EncryptLevel.NOT_ENCRYPT.toString());
} else {
localFLParameter.setEncryptLevel(encryptLevel);
@ -113,10 +127,10 @@ public class FLLiteClient {
}
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
if (minSecretNum <= 0) {
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid: <=0"));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid:" +
" <=0"));
return -1;
}
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
break;
case DP_ENCRYPT:
dpEps = cipherPublicParams.dpEps();
@ -124,53 +138,97 @@ public class FLLiteClient {
dpNormClipFactor = cipherPublicParams.dpNormClip();
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " + dpNormClipFactor));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " +
dpNormClipFactor));
break;
default:
LOGGER.info(Common.addTag("[startFLJob] NotEncrypt, do not set parameter for Encrypt"));
LOGGER.info(Common.addTag("[startFLJob] NOT_ENCRYPT, do not set parameter for Encrypt"));
}
return 0;
}
/**
* Obtain retCode returned by server.
*
* @return the retCode returned by server.
*/
public int getRetCode() {
return retCode;
}
/**
* Obtain current iteration returned by server.
*
* @return the current iteration returned by server.
*/
public int getIteration() {
return iteration;
}
/**
* Obtain total iterations for the task returned by server.
*
* @return the total iterations for the task returned by server.
*/
public int getIterations() {
return iterations;
}
public int getEpochs() {
return epochs;
}
public int getBatchSize() {
return batchSize;
}
/**
* Obtain the returned timestamp for next request from server.
*
* @return the timestamp for next request.
*/
public String getNextRequestTime() {
return nextRequestTime;
}
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
/**
* set the size of train date set.
*
* @param trainDataSize the size of train date set.
*/
public void setTrainDataSize(int trainDataSize) {
this.trainDataSize = trainDataSize;
}
public FLClientStatus checkStatus() {
return this.status;
/**
* Obtain the dpNormClipFactor.
*
* @return the dpNormClipFactor.
*/
public double getDpNormClipFactor() {
return dpNormClipFactor;
}
/**
* Obtain the dpNormClipAdapt.
*
* @return the dpNormClipAdapt.
*/
public double getDpNormClipAdapt() {
return dpNormClipAdapt;
}
/**
* Set the dpNormClipAdapt.
*
* @param dpNormClipAdapt the dpNormClipAdapt.
*/
public void setDpNormClipAdapt(double dpNormClipAdapt) {
this.dpNormClipAdapt = dpNormClipAdapt;
}
/**
* Send serialized request message of startFLJob to server.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus startFLJob() {
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify server===================================="));
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify " +
"server===================================="));
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
StartFLJob startFLJob = StartFLJob.getInstance();
Date date = new Date();
long time = date.getTime();
@ -179,8 +237,9 @@ public class FLLiteClient {
long start = Common.startTime("single startFLJob");
LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length));
byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[startFLJob] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(message)) {
LOGGER.info(Common.addTag("[startFLJob] the server is not ready now, need wait some time and request " +
"again"));
status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
@ -193,14 +252,15 @@ public class FLLiteClient {
status = judgeStartFLJob(startFLJob, responseDataBuf);
retCode = responseDataBuf.retcode();
} catch (IOException e) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage()));
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " +
e.getMessage()));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
return status;
}
public FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
private FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
iteration = responseDataBuf.iteration();
FLClientStatus response = startFLJob.doResponse(responseDataBuf);
status = response;
@ -218,6 +278,10 @@ public class FLLiteClient {
break;
case RESTART:
FLPlan flPlan = responseDataBuf.flPlanConfig();
if (flPlan == null) {
LOGGER.severe(Common.addTag("[startFLJob] the flPlan returned from server is null"));
return FLClientStatus.FAILED;
}
iterations = flPlan.iterations();
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
nextRequestTime = responseDataBuf.nextReqTime();
@ -226,14 +290,21 @@ public class FLLiteClient {
LOGGER.severe(Common.addTag("[startFLJob] startFLJob failed"));
break;
default:
LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range <SUCCESS, WAIT, FAILED, Restart>"));
LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range " +
"<SUCCESS, WAIT, FAILED, Restart>"));
status = FLClientStatus.FAILED;
}
return status;
}
/**
* Define the training process.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus localTrain() {
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "===================================="));
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration +
"===================================="));
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ALBERT)) {
@ -254,12 +325,22 @@ public class FLLiteClient {
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
} else {
LOGGER.severe(Common.addTag("[train] the flName is not valid"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
return status;
}
/**
* Send serialized request message of updateModel to server.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus updateModel() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
UpdateModel updateModelBuf = UpdateModel.getInstance();
byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize);
if (updateModelBuf.getStatus() == FLClientStatus.FAILED) {
@ -270,8 +351,9 @@ public class FLLiteClient {
long start = Common.startTime("single updateModel");
LOGGER.info(Common.addTag("[updateModel] the request message length: " + updateModelBuffer.length));
byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[updateModel] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(message)) {
LOGGER.info(Common.addTag("[updateModel] the server is not ready now, need wait some time and request" +
" again"));
status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
@ -288,23 +370,31 @@ public class FLLiteClient {
}
LOGGER.info(Common.addTag("[updateModel] get response from server ok!"));
} catch (IOException e) {
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage()));
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " +
e.getMessage()));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
}
return status;
}
/**
* Send serialized request message of getModel to server.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus getModel() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration);
try {
long start = Common.startTime("single getModel");
LOGGER.info(Common.addTag("[getModel] the request message length: " + buffer.length));
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(message)) {
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
"again"));
status = FLClientStatus.WAIT;
return status;
}
@ -327,6 +417,12 @@ public class FLLiteClient {
return status;
}
/**
* Obtain the weight of the model before training.
*
* @param map a map to store the weight of the model.
* @return the weight.
*/
public static synchronized Map<String, float[]> getOldMapCopy(Map<String, float[]> map) {
if (mapBeforeTrain == null) {
Map<String, float[]> copyMap = new TreeMap<>();
@ -334,7 +430,8 @@ public class FLLiteClient {
float[] data = map.get(key);
int dataLen = data.length;
float[] weights = new float[dataLen];
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) &&
(key.indexOf("learning") < 0)) {
for (int j = 0; j < dataLen; j++) {
float weight = data[j];
weights[j] = weight;
@ -348,7 +445,8 @@ public class FLLiteClient {
float[] data = map.get(key);
float[] copyData = mapBeforeTrain.get(key);
int dataLen = data.length;
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) &&
(key.indexOf("learning") < 0)) {
for (int j = 0; j < dataLen; j++) {
copyData[j] = data[j];
}
@ -358,18 +456,25 @@ public class FLLiteClient {
return mapBeforeTrain;
}
/**
* Obtain pairwise mask and individual mask.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus getFeatureMask() {
FLClientStatus curStatus;
switch (localFLParameter.getEncryptLevel()) {
case PW_ENCRYPT:
LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" + localFLParameter.getEncryptLevel().toString() + ">"));
LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" +
localFLParameter.getEncryptLevel().toString() + ">"));
secureProtocol.setPWParameter(iteration, minSecretNum, prime, featureSize);
curStatus = secureProtocol.pwCreateMask();
if (curStatus == FLClientStatus.RESTART) {
nextRequestTime = secureProtocol.getNextRequestTime();
}
retCode = secureProtocol.getRetCode();
LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" + localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" +
localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
return curStatus;
case DP_ENCRYPT:
Map<String, float[]> map = new HashMap<String, float[]>();
@ -388,7 +493,7 @@ public class FLLiteClient {
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[Encrypt] set parameters for DPEncrypt!"));
LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!"));
return FLClientStatus.SUCCESS;
case NOT_ENCRYPT:
retCode = ResponseCode.SUCCEED;
@ -401,6 +506,11 @@ public class FLLiteClient {
}
}
/**
* Reconstruct the secrets used for unmasking model weights.
*
* @return current status code in client.
*/
public FLClientStatus unMasking() {
FLClientStatus curStatus;
switch (localFLParameter.getEncryptLevel()) {
@ -413,7 +523,7 @@ public class FLLiteClient {
}
return curStatus;
case DP_ENCRYPT:
LOGGER.info(Common.addTag("[Encrypt] DPEncrypt do not need unmasking"));
LOGGER.info(Common.addTag("[Encrypt] DP_ENCRYPT do not need unmasking"));
retCode = ResponseCode.SUCCEED;
return FLClientStatus.SUCCESS;
case NOT_ENCRYPT:
@ -427,18 +537,26 @@ public class FLLiteClient {
}
}
/**
* Evaluate model after getting model from server.
*
* @return the status code in client.
*/
public FLClientStatus evaluateModel() {
status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("===================================evaluate model after getting model from server==================================="));
LOGGER.info(Common.addTag("===================================evaluate model after getting model from " +
"server==================================="));
if (flParameter.getFlName().equals(ALBERT)) {
float acc = 0;
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlInferBert alInferBert = AlInferBert.getInstance();
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
flParameter.getIdsFile(), true);
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the return dataSize<=0"));
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the " +
"return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
@ -447,47 +565,66 @@ public class FLLiteClient {
} else {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile());
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
flParameter.getIdsFile());
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0"));
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the " +
"return dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
acc = alTrainBert.evalModel();
}
if (acc == Float.NaN) {
if (Float.isNaN(acc)) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <evalModel>: the return acc is NAN"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() +
" idsFile: " + flParameter.getIdsFile()));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0], flParameter.getTestDataset().split(",")[1]);
if (flParameter.getTestDataset().split(",").length < 2) {
LOGGER.severe(Common.addTag("[evaluate] the set testDataPath for lenet is not valid, should be the " +
"format of <data.bin,label.bin> "));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0],
flParameter.getTestDataset().split(",")[1]);
if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return dataSize<=0"));
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return " +
"dataSize<=0"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
float acc = trainLenet.evalModel();
if (acc == Float.NaN) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc is NAN"));
if (Float.isNaN(acc)) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc" +
" is NAN"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset().split(",")[0] + " labelPath: " + flParameter.getTestDataset().split(",")[1]));
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
flParameter.getTestDataset().split(",")[0] + " labelPath: " +
flParameter.getTestDataset().split(",")[1]));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
}
return status;
}
/**
* @param dataPath, train or test dataset and label set
* Set date path.
*
* @param dataPath, train or test dataset and label set.
* @return date size.
*/
public int setInput(String dataPath) {
retCode = ResponseCode.SUCCEED;
@ -496,15 +633,18 @@ public class FLLiteClient {
if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance();
dataSize = alTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " " +
"vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
if (dataPath.split(",").length < 2) {
LOGGER.info(Common.addTag("[set input] the set dataPath for lenet is not valid, should be the format of <data.bin,label.bin>"));
LOGGER.severe(Common.addTag("[set input] the set dataPath for lenet is not valid, should be the " +
"format of <data.bin,label.bin> "));
return -1;
}
dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]);
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " + +dataSize + " labelPath: " + dataPath.split(",")[1]));
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " +
dataSize + " labelPath: " + dataPath.split(",")[1]));
}
if (dataSize <= 0) {
retCode = ResponseCode.RequestError;
@ -513,36 +653,48 @@ public class FLLiteClient {
return dataSize;
}
/**
* Initialization session.
*
* @return the status code in client.
*/
public FLClientStatus initSession() {
int tag = 0;
retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
"Train Session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
"is -1"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " +
"Create inference Session============="));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
"Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
}
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is " +
"-1"));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
}
@Override
protected void finalize() {
/**
* Free session.
*/
protected void freeSession() {
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("===========free train session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
@ -558,5 +710,4 @@ public class FLLiteClient {
SessionUtil.free(trainLenet.getTrainSession());
}
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,22 +13,36 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import java.util.logging.Logger;
package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import java.util.Arrays;
import java.util.UUID;
import java.util.logging.Logger;
/**
* Defines global parameters used during federated learning and these parameters are provided for users to set.
*
* @since 2021-06-30
*/
public class FLParameter {
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
/**
* The timeout interval for communication on the device.
*/
public static final int TIME_OUT = 100;
/**
* The waiting time of repeated requests.
*/
public static final int SLEEP_TIME = 1000;
private static volatile FLParameter flParameter;
private String hostName;
private String domainName;
private String certPath;
private boolean useHttps = false;
private String trainDataset;
private String vocabFile = "null";
private String idsFile = "null";
@ -37,22 +51,21 @@ public class FLParameter {
private String trainModelPath;
private String inferModelPath;
private String clientID;
private String ip;
private int port;
private boolean useSSL = false;
private int timeOut;
private int sleepTime;
private boolean useElb = false;
private boolean ifUseElb = false;
private int serverNum = 1;
private boolean timer = true;
private int timeWindow = 6000;
private int reRequestNum = timeWindow / SLEEP_TIME + 1;
private static volatile FLParameter flParameter;
private FLParameter() {}
private FLParameter() {
clientID = UUID.randomUUID().toString();
}
/**
* Get the singleton object of the class FLParameter.
*
* @return the singleton object of the class FLParameter.
*/
public static FLParameter getInstance() {
FLParameter localRef = flParameter;
if (localRef == null) {
@ -66,95 +79,100 @@ public class FLParameter {
return localRef;
}
public String getHostName() {
if ("".equals(hostName) || hostName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <hostName> is null, please set it before use"));
throw new RuntimeException();
public String getDomainName() {
if (domainName == null || domainName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is null or empty, please set it " +
"before use"));
throw new IllegalArgumentException();
}
return hostName;
return domainName;
}
public void setHostName(String hostName) {
this.hostName = hostName;
public void setDomainName(String domainName) {
if (domainName == null || domainName.isEmpty() || (!("https".equals(domainName.split(":")[0]) || "http".equals(domainName.split(":")[0])))) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is not valid, it should be like " +
"as https://...... or http://......, please check it before set"));
throw new IllegalArgumentException();
}
this.domainName = domainName;
}
public String getCertPath() {
if ("".equals(certPath) || certPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null, please set it before use"));
throw new RuntimeException();
if (certPath == null || certPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null or empty, please set it " +
"before use"));
throw new IllegalArgumentException();
}
return certPath;
}
public void setCertPath(String certPath) {
certPath = Common.getRealPath(certPath);
if (Common.checkPath(certPath)) {
this.certPath = certPath;
String realCertPath = Common.getRealPath(certPath);
if (Common.checkPath(realCertPath)) {
this.certPath = realCertPath;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it before set"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it " +
"before set"));
throw new IllegalArgumentException();
}
}
public boolean isUseHttps() {
return useHttps;
}
public void setUseHttps(boolean useHttps) {
this.useHttps = useHttps;
}
public String getTrainDataset() {
if ("".equals(trainDataset) || trainDataset.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null, please set it before use"));
throw new RuntimeException();
if (trainDataset == null || trainDataset.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null or empty, please set " +
"it before use"));
throw new IllegalArgumentException();
}
return trainDataset;
}
public void setTrainDataset(String trainDataset) {
trainDataset = Common.getRealPath(trainDataset);
if (Common.checkPath(trainDataset)) {
this.trainDataset = trainDataset;
String realTrainDataset = Common.getRealPath(trainDataset);
if (Common.checkPath(realTrainDataset)) {
this.trainDataset = realTrainDataset;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it before set"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it " +
"before set"));
throw new IllegalArgumentException();
}
}
public String getVocabFile() {
if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before " +
"use"));
throw new IllegalArgumentException();
}
return vocabFile;
}
public void setVocabFile(String vocabFile) {
vocabFile = Common.getRealPath(vocabFile);
if (Common.checkPath(vocabFile)) {
this.vocabFile = vocabFile;
String realVocabFile = Common.getRealPath(vocabFile);
if (Common.checkPath(realVocabFile)) {
this.vocabFile = realVocabFile;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it before set"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it " +
"before set"));
throw new IllegalArgumentException();
}
}
public String getIdsFile() {
if ("null".equals(idsFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
throw new RuntimeException();
throw new IllegalArgumentException();
}
return idsFile;
}
public void setIdsFile(String idsFile) {
idsFile = Common.getRealPath(idsFile);
if (Common.checkPath(idsFile)) {
this.idsFile = idsFile;
String realIdsFile = Common.getRealPath(idsFile);
if (Common.checkPath(realIdsFile)) {
this.idsFile = realIdsFile;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it before set"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it " +
"before set"));
throw new IllegalArgumentException();
}
}
@ -163,19 +181,21 @@ public class FLParameter {
}
public void setTestDataset(String testDataset) {
testDataset = Common.getRealPath(testDataset);
if (Common.checkPath(testDataset)) {
this.testDataset = testDataset;
String realTestDataset = Common.getRealPath(testDataset);
if (Common.checkPath(realTestDataset)) {
this.testDataset = realTestDataset;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it before set"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it " +
"before set"));
throw new IllegalArgumentException();
}
}
public String getFlName() {
if ("".equals(flName) || flName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null, please set it before use"));
throw new RuntimeException();
if (flName == null || flName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null or empty, please set it " +
"before use"));
throw new IllegalArgumentException();
}
return flName;
}
@ -184,61 +204,50 @@ public class FLParameter {
if (Common.checkFLName(flName)) {
this.flName = flName;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in flNameTrustList, please check it before set"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in FL_NAME_TRUST_LIST: " +
Arrays.toString(Common.FL_NAME_TRUST_LIST.toArray(new String[0])) + ", please check it before " +
"set"));
throw new IllegalArgumentException();
}
}
public String getTrainModelPath() {
if ("".equals(trainModelPath) || trainModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null, please set it before use"));
throw new RuntimeException();
if (trainModelPath == null || trainModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null or empty, please set" +
" it before use"));
throw new IllegalArgumentException();
}
return trainModelPath;
}
public void setTrainModelPath(String trainModelPath) {
trainModelPath = Common.getRealPath(trainModelPath);
if (Common.checkPath(trainModelPath)) {
this.trainModelPath = trainModelPath;
String realTrainModelPath = Common.getRealPath(trainModelPath);
if (Common.checkPath(realTrainModelPath)) {
this.trainModelPath = realTrainModelPath;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check it before set"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check " +
"it before set"));
throw new IllegalArgumentException();
}
}
public String getInferModelPath() {
if ("".equals(inferModelPath) || inferModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null, please set it before use"));
throw new RuntimeException();
if (inferModelPath == null || inferModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null or empty, please set" +
" it before use"));
throw new IllegalArgumentException();
}
return inferModelPath;
}
public void setInferModelPath(String inferModelPath) {
inferModelPath = Common.getRealPath(inferModelPath);
if (Common.checkPath(inferModelPath)) {
this.inferModelPath = inferModelPath;
String realInferModelPath = Common.getRealPath(inferModelPath);
if (Common.checkPath(realInferModelPath)) {
this.inferModelPath = realInferModelPath;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check it before set"));
throw new RuntimeException();
}
}
public String getIp() {
if ("".equals(ip) || ip.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is null, please set it before use"));
throw new RuntimeException();
}
return ip;
}
public void setIp(String ip) {
if (Common.checkIP(ip)) {
this.ip = ip;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is not valid, please check it before set"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check " +
"it before set"));
throw new IllegalArgumentException();
}
}
@ -250,23 +259,6 @@ public class FLParameter {
this.useSSL = useSSL;
}
public int getPort() {
if (port == 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is null, please set it before use"));
throw new RuntimeException();
}
return port;
}
public void setPort(int port) {
if (Common.checkPort(port)) {
this.port = port;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is not valid, please check it before set"));
throw new RuntimeException();
}
}
public int getTimeOut() {
return timeOut;
}
@ -284,17 +276,18 @@ public class FLParameter {
}
public boolean isUseElb() {
return useElb;
return ifUseElb;
}
public void setUseElb(boolean useElb) {
this.useElb = useElb;
public void setUseElb(boolean ifUseElb) {
this.ifUseElb = ifUseElb;
}
public int getServerNum() {
if (serverNum <= 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> is <= 0, it should be > 0, please set it before use"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> <= 0, it should be > 0, please " +
"set it before use"));
throw new IllegalArgumentException();
}
return serverNum;
}
@ -303,40 +296,11 @@ public class FLParameter {
this.serverNum = serverNum;
}
public boolean isTimer() {
return timer;
}
public void setTimer(boolean timer) {
this.timer = timer;
}
public int getTimeWindow() {
return timeWindow;
}
public void setTimeWindow(int timeWindow) {
this.timeWindow = timeWindow;
}
public int getReRequestNum() {
return reRequestNum;
}
public void setReRequestNum(int reRequestNum) {
this.reRequestNum = reRequestNum;
}
public String getClientID() {
if ("".equals(clientID) || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null, please set it before use"));
throw new RuntimeException();
if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check"));
throw new IllegalArgumentException();
}
return clientID;
}
public void setClientID(String clientID) {
this.clientID = clientID;
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,13 +13,19 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestGetModel;
import mindspore.schema.ResponseCode;
@ -29,57 +35,30 @@ import java.util.ArrayList;
import java.util.Date;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
/**
* Define the serialization method, handle the response message returned from server for getModel request.
*
* @since 2021-06-30
*/
public class GetModel {
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
private static volatile GetModel getModel;
static {
System.loadLibrary("mindspore-lite-jni");
}
class RequestGetModelBuilder {
private FlatBufferBuilder builder;
private int nameOffset = 0;
private int iteration = 0;
private int timeStampOffset = 0;
public RequestGetModelBuilder() {
builder = new FlatBufferBuilder();
}
public RequestGetModelBuilder flName(String name) {
this.nameOffset = this.builder.createString(name);
return this;
}
public RequestGetModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timeStampOffset = builder.createString(String.valueOf(time));
return this;
}
public RequestGetModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
public byte[] build() {
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
builder.finish(root);
return builder.sizedByteArray();
}
}
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
private static volatile GetModel getModel;
private GetModel() {
}
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private GetModel() {
}
/**
* Get the singleton object of the class GetModel.
*
* @return the singleton object of the class GetModel.
*/
public static GetModel getInstance() {
GetModel localRef = getModel;
if (localRef == null) {
@ -93,7 +72,18 @@ public class GetModel {
return localRef;
}
/**
* Get a flatBuffer builder of RequestGetModel.
*
* @param name the model name.
* @param iteration current iteration of federated learning task.
* @return the flatBuffer builder of RequestGetModel in byte[] format.
*/
public byte[] getRequestGetModel(String name, int iteration) {
if (name == null || name.isEmpty()) {
LOGGER.severe(Common.addTag("[GetModel] the input parameter of <name> is null or empty, please check!"));
throw new IllegalArgumentException();
}
RequestGetModelBuilder builder = new RequestGetModelBuilder();
return builder.iteration(iteration).flName(name).time().build();
}
@ -107,6 +97,10 @@ public class GetModel {
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
@ -116,36 +110,46 @@ public class GetModel {
} else {
continue;
}
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength:" +
" " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------"));
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " +
"model-----------------"));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
@ -159,9 +163,14 @@ public class GetModel {
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
@ -174,7 +183,12 @@ public class GetModel {
return FLClientStatus.SUCCESS;
}
/**
* Handle the response message returned from server.
*
* @param responseDataBuf the response message returned from server.
* @return the status code corresponding to the response message.
*/
public FLClientStatus doResponse(ResponseGetModel responseDataBuf) {
LOGGER.info(Common.addTag("[getModel] ==========get model content is:================"));
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode()));
@ -186,13 +200,15 @@ public class GetModel {
switch (retCode) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[getModel] getModel response success"));
if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
status = parseResponseAlbert(responseDataBuf);
} else if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
status = parseResponseLenet(responseDataBuf);
} else {
LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert"));
throw new IllegalArgumentException();
}
return status;
case (ResponseCode.SucNotReady):
@ -211,4 +227,42 @@ public class GetModel {
}
}
class RequestGetModelBuilder {
private FlatBufferBuilder builder;
private int nameOffset = 0;
private int iteration = 0;
private int timeStampOffset = 0;
public RequestGetModelBuilder() {
builder = new FlatBufferBuilder();
}
private RequestGetModelBuilder flName(String name) {
if (name == null || name.isEmpty()) {
LOGGER.severe(Common.addTag("[GetModel] the input parameter of <name> is null or empty, please " +
"check!"));
throw new IllegalArgumentException();
}
this.nameOffset = this.builder.createString(name);
return this;
}
private RequestGetModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timeStampOffset = builder.createString(String.valueOf(time));
return this;
}
private RequestGetModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
private byte[] build() {
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
builder.finish(root);
return builder.sizedByteArray();
}
}
}

View File

@ -1,25 +1,42 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import java.io.IOException;
/**
* Define asynchronous communication call back interface.
*
* @since 2021-06-30
*/
public interface IAsyncCallBack {
public FLClientStatus onFailure(IOException exception);
/**
* Automatically invoked when the request fails.
*
* @param exception the catch exception.
* @return the status code in client.
*/
FLClientStatus onFailure(IOException exception);
public FLClientStatus onResponse(byte[] msg);
/**
* Automatically invoked when a response message is processed.
*
* @param msg the response message.
* @return the status code in client.
*/
FLClientStatus onResponse(byte[] msg);
}

View File

@ -1,33 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import java.util.concurrent.TimeoutException;
/**
* @author smurf
* Define basic communication interface.
*
* @since 2021-06-30
*/
public interface IFLCommunication {
/**
* Sets the timeout interval for communication on the device.
*
* @param timeout the timeout interval for communication on the device.
* @throws TimeoutException catch TimeoutException.
*/
void setTimeOut(int timeout) throws TimeoutException;
public void setTimeOut(int timeout) throws TimeoutException;
public byte[] syncRequest(String url, byte[] msg) throws Exception;
public void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception;
}
/**
* Synchronization request function.
*
* @param url the URL for device-sever interaction set by user.
* @param msg the message need to be sent to server.
* @return the response message.
* @throws Exception catch Exception.
*/
byte[] syncRequest(String url, byte[] msg) throws Exception;
/**
* Asynchronous request function.
*
* @param url the URL for device-sever interaction set by user.
* @param msg the message need to be sent to server.
* @param callBack the call back object.
* @throws Exception catch Exception.
*/
void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,22 +13,30 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
/**
* Define job result callback function interface.
*
* @since 2021-06-30
*/
public interface IFLJobResultCallback {
/**
* Called at the end of an iteration for Fl job
* @param modelName the name of model
* Called at the end of an iteration for Fl job
*
* @param modelName the name of model
* @param iterationSeq Iteration number
* @param resultCode Status Code
* @param resultCode Status Code
*/
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode);
void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode);
/**
* Called on completion for Fl job
* @param modelName the name of model
*
* @param modelName the name of model
* @param iterationCount total Iteration numbers
* @param resultCode Status Code
* @param resultCode Status Code
*/
public void onFlJobFinished(String modelName, int iterationCount, int resultCode);
void onFlJobFinished(String modelName, int iterationCount, int resultCode);
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,28 +13,60 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import org.bouncycastle.math.ec.rfc7748.X25519;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
/**
* Defines global parameters used internally during federated learning.
*
* @since 2021-06-30
*/
public class LocalFLParameter {
private static final Logger LOGGER = Logger.getLogger(LocalFLParameter.class.toString());
/**
* Seed length used to generate random perturbations
*/
public static final int SEED_SIZE = 32;
public static final int IVEC_LEN = 16;
/**
* The length of IV value
*/
public static final int I_VEC_LEN = 16;
/**
* The length of salt value
*/
public static final int SALT_SIZE = 32;
/**
* the key length
*/
public static final int KEY_LEN = X25519.SCALAR_SIZE;
/**
* The model name supported by federated learning tasks: "lenet".
*/
public static final String LENET = "lenet";
/**
* The model name supported by federated learning tasks: "albert".
*/
public static final String ALBERT = "albert";
private static volatile LocalFLParameter localFLParameter;
private List<String> classifierWeightName = new ArrayList<>();
private List<String> albertWeightName = new ArrayList<>();
private String flID;
private String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString();
private String earlyStopMod = EarlyStopMod.NOT_EARLY_STOP.toString();
private String serverMod = ServerMod.HYBRID_TRAINING.toString();
private String safeMod = "The cluster is in safemode.";
private static volatile LocalFLParameter localFLParameter;
private LocalFLParameter() {
// set classifierWeightName albertWeightName
@ -42,6 +74,11 @@ public class LocalFLParameter {
Common.setAlbertWeightName(albertWeightName);
}
/**
* Get the singleton object of the class LocalFLParameter.
*
* @return the singleton object of the class LocalFLParameter.
*/
public static LocalFLParameter getInstance() {
LocalFLParameter localRef = localFLParameter;
if (localRef == null) {
@ -57,8 +94,9 @@ public class LocalFLParameter {
public List<String> getClassifierWeightName() {
if (classifierWeightName.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please " +
"set it before use"));
throw new IllegalArgumentException();
}
return classifierWeightName;
}
@ -69,8 +107,9 @@ public class LocalFLParameter {
public List<String> getAlbertWeightName() {
if (albertWeightName.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please " +
"set it before use"));
throw new IllegalArgumentException();
}
return albertWeightName;
}
@ -80,14 +119,20 @@ public class LocalFLParameter {
}
public String getFlID() {
if ("".equals(flID) || flID == null) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before use"));
throw new RuntimeException();
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before " +
"use"));
throw new IllegalArgumentException();
}
return flID;
}
public void setFlID(String flID) {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please check it before " +
"set"));
throw new IllegalArgumentException();
}
this.flID = flID;
}
@ -96,6 +141,18 @@ public class LocalFLParameter {
}
public void setEncryptLevel(String encryptLevel) {
if (encryptLevel == null || encryptLevel.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is null, please check it " +
"before set"));
throw new IllegalArgumentException();
}
if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) &&
(!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) &&
(!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is " + encryptLevel + " ," +
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before set"));
throw new IllegalArgumentException();
}
this.encryptLevel = encryptLevel;
}
@ -104,6 +161,19 @@ public class LocalFLParameter {
}
public void setEarlyStopMod(String earlyStopMod) {
if (earlyStopMod == null || earlyStopMod.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is null, please check it " +
"before set"));
throw new IllegalArgumentException();
}
if ((!EarlyStopMod.NOT_EARLY_STOP.toString().equals(earlyStopMod)) &&
(!EarlyStopMod.LOSS_ABS.toString().equals(earlyStopMod)) &&
(!EarlyStopMod.LOSS_DIFF.toString().equals(earlyStopMod)) &&
(!EarlyStopMod.WEIGHT_DIFF.toString().equals(earlyStopMod))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is " + earlyStopMod + " ," +
" it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before set"));
throw new IllegalArgumentException();
}
this.earlyStopMod = earlyStopMod;
}
@ -112,14 +182,17 @@ public class LocalFLParameter {
}
public void setServerMod(String serverMod) {
if (serverMod == null || serverMod.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is null, please check it " +
"before set"));
throw new IllegalArgumentException();
}
if ((!ServerMod.HYBRID_TRAINING.toString().equals(serverMod)) &&
(!ServerMod.FEDERATED_LEARNING.toString().equals(serverMod))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is " + serverMod + " , it " +
"must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before set"));
throw new IllegalArgumentException();
}
this.serverMod = serverMod;
}
public String getSafeMod() {
return safeMod;
}
public void setSafeMod(String safeMod) {
this.safeMod = safeMod;
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,33 +13,73 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.security.InvalidKeyException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.SignatureException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.logging.Logger;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.FileInputStream;
import java.io.InputStream;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.SignatureException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.logging.Logger;
/**
* Define SSL socket factory tools for https communication.
*
* @since 2021-06-30
*/
public class SSLSocketFactoryTools {
private static final Logger LOGGER = Logger.getLogger(SSLSocketFactory.class.toString());
private static volatile SSLSocketFactoryTools sslSocketFactoryTools;
private FLParameter flParameter = FLParameter.getInstance();
private X509Certificate x509Certificate;
private SSLSocketFactory sslSocketFactory;
private SSLContext sslContext;
private MyTrustManager myTrustManager;
private static volatile SSLSocketFactoryTools sslSocketFactoryTools;
private final HostnameVerifier hostnameVerifier = new HostnameVerifier() {
@Override
public boolean verify(String hostname, SSLSession session) {
if (hostname == null || hostname.isEmpty()) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <hostname> is null or empty, " +
"please check!"));
throw new IllegalArgumentException();
}
if (session == null) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <session> is null, please " +
"check!"));
throw new IllegalArgumentException();
}
String domainName = flParameter.getDomainName();
if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the <domainName> is null or not valid, it should" +
" be like as https://...... , please check!"));
throw new IllegalArgumentException();
}
if (domainName.split("//")[1].split(":").length < 2) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the format of <domainName> is not valid, it " +
"should be like as https://127.0.0.1:6666 when setting <useSSL> to true, please check!"));
throw new IllegalArgumentException();
}
String ip = domainName.split("//")[1].split(":")[0];
return hostname.equals(ip);
}
};
private SSLSocketFactoryTools() {
initSslSocketFactory();
@ -52,14 +92,19 @@ public class SSLSocketFactoryTools {
myTrustManager = new MyTrustManager(x509Certificate);
sslContext.init(null, new TrustManager[]{
myTrustManager
}, new java.security.SecureRandom());
}, Common.getSecureRandom());
sslSocketFactory = sslContext.getSocketFactory();
} catch (Exception e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " + e.getMessage()));
} catch (NoSuchAlgorithmException | KeyManagementException ex) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " +
ex.getMessage()));
}
}
/**
* Get the singleton object of the class SSLSocketFactoryTools.
*
* @return the singleton object of the class SSLSocketFactoryTools.
*/
public static SSLSocketFactoryTools getInstance() {
SSLSocketFactoryTools localRef = sslSocketFactoryTools;
if (localRef == null) {
@ -73,29 +118,37 @@ public class SSLSocketFactoryTools {
return localRef;
}
public X509Certificate readCert(String assetName) {
InputStream inputStream = null;
try {
inputStream = new FileInputStream(assetName);
} catch (Exception e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of read inputStream in readCert: " + e.getMessage()));
private X509Certificate readCert(String assetName) {
if (assetName == null || assetName.isEmpty()) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <assetName> is null or empty, " +
"please check!"));
return null;
}
InputStream inputStream = null;
X509Certificate cert = null;
try {
inputStream = new FileInputStream(assetName);
CertificateFactory cf = CertificateFactory.getInstance("X.509");
cert = (X509Certificate) cf.generateCertificate(inputStream);
} catch (Exception e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of creating CertificateFactory in readCert: " + e.getMessage()));
Certificate certificate = cf.generateCertificate(inputStream);
if (certificate instanceof X509Certificate) {
cert = (X509Certificate) certificate;
} else {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] cf.generateCertificate(inputStream) can not " +
"convert to X509Certificate"));
}
} catch (FileNotFoundException | CertificateException ex) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch FileNotFoundException or CertificateException " +
"when creating " +
"CertificateFactory in readCert: " + ex.getMessage()));
} finally {
try {
if (inputStream != null) {
inputStream.close();
}
} catch (Throwable ex) {
} catch (IOException ex) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch IOException: " + ex.getMessage()));
}
}
return cert;
}
@ -111,7 +164,6 @@ public class SSLSocketFactoryTools {
return myTrustManager;
}
private static final class MyTrustManager implements X509TrustManager {
X509Certificate cert;
@ -126,27 +178,25 @@ public class SSLSocketFactoryTools {
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
for (X509Certificate cert : chain) {
// Make sure that it hasn't expired.
cert.checkValidity();
// Verify the certificate's public key chain.
try {
cert.verify(((X509Certificate) this.cert).getPublicKey());
cert.verify(this.cert.getPublicKey());
} catch (NoSuchAlgorithmException e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage()));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
"NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage()));
} catch (InvalidKeyException e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch InvalidKeyException in checkServerTrusted: " + e.getMessage()));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
"InvalidKeyException in checkServerTrusted: " + e.getMessage()));
} catch (NoSuchProviderException e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchProviderException in checkServerTrusted: " + e.getMessage()));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
"NoSuchProviderException in checkServerTrusted: " + e.getMessage()));
} catch (SignatureException e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch SignatureException in checkServerTrusted: " + e.getMessage()));
throw new RuntimeException();
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
"SignatureException in checkServerTrusted: " + e.getMessage()));
}
LOGGER.info(Common.addTag("checkServerTrusted success!"));
LOGGER.info(Common.addTag("**********************checkServerTrusted success!**********************"));
}
}
@ -155,14 +205,4 @@ public class SSLSocketFactoryTools {
return new java.security.cert.X509Certificate[0];
}
}
private final HostnameVerifier hostnameVerifier = new HostnameVerifier() {
@Override
public boolean verify(String hostname, SSLSession session) {
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] server hostname: " + flParameter.getHostName()));
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] client request hostname: " + hostname));
return hostname.equals(flParameter.getHostName());
}
};
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,127 +13,179 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
/**
* Defines encryption and decryption methods.
*
* @since 2021-06-30
*/
public class SecureProtocol {
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
private static double deltaError = 1e-6d;
private static Map<String, float[]> modelMap;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int iteration;
private CipherClient cipher;
private CipherClient cipherClient;
private FLClientStatus status;
private float[] featureMask = new float[0];
private double dpEps;
private double dpDelta;
private double dpNormClip;
private static double deltaError = 1e-6;
private static Map<String, float[]> modelMap;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private int retCode;
/**
* Obtain current status code in client.
*
* @return current status code in client.
*/
public FLClientStatus getStatus() {
return status;
}
public float[] getFeatureMask() {
return featureMask;
}
/**
* Obtain retCode returned by server.
*
* @return the retCode returned by server.
*/
public int getRetCode() {
return retCode;
}
public SecureProtocol() {
}
/**
* Setting parameters for pairwise masking.
*
* @param iter current iteration of federated learning task.
* @param minSecretNum minimum number of secret fragments required to reconstruct a secret
* @param prime teh big prime number used to split secrets into pieces
* @param featureSize the total feature size in model
*/
public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) {
this.iteration = iter;
this.cipher = new CipherClient(iteration, minSecretNum, prime, featureSize);
}
public FLClientStatus setDPParameter(int iter, double diffEps,
double diffDelta, double diffNorm, Map<String, float[]> map) {
try {
this.iteration = iter;
this.dpEps = diffEps;
this.dpDelta = diffDelta;
this.dpNormClip = diffNorm;
this.modelMap = map;
status = FLClientStatus.SUCCESS;
} catch (Exception e) {
LOGGER.severe(Common.addTag("[DPEncrypt] catch Exception in setDPParameter: " + e.getMessage()));
status = FLClientStatus.FAILED;
if (prime == null || prime.length == 0) {
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <prime> is null, please check!"));
throw new IllegalArgumentException();
}
return status;
this.iteration = iter;
this.cipherClient = new CipherClient(iteration, minSecretNum, prime, featureSize);
}
/**
* Setting parameters for differential privacy.
*
* @param iter current iteration of federated learning task.
* @param diffEps privacy budget eps of DP mechanism.
* @param diffDelta privacy budget delta of DP mechanism.
* @param diffNorm normClip factor of DP mechanism.
* @param map model weights.
* @return the status code corresponding to the response message.
*/
public FLClientStatus setDPParameter(int iter, double diffEps, double diffDelta, double diffNorm, Map<String,
float[]> map) {
this.iteration = iter;
this.dpEps = diffEps;
this.dpDelta = diffDelta;
this.dpNormClip = diffNorm;
this.modelMap = map;
return FLClientStatus.SUCCESS;
}
/**
* Obtain the feature names that needed to be encrypted.
*
* @return the feature names that needed to be encrypted.
*/
public ArrayList<String> getEncryptFeatureName() {
return encryptFeatureName;
}
/**
* Set the parameter encryptFeatureName.
*
* @param encryptFeatureName the feature names that needed to be encrypted.
*/
public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) {
this.encryptFeatureName = encryptFeatureName;
}
/**
* Obtain the returned timestamp for next request from server.
*
* @return the timestamp for next request.
*/
public String getNextRequestTime() {
return cipher.getNextRequestTime();
return cipherClient.getNextRequestTime();
}
/**
* Generate pairwise mask and individual mask.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus pwCreateMask() {
LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "==============");
LOGGER.info(String.format("[PairWiseMask] ==============request flID: %s ==============",
localFLParameter.getFlID()));
// round 0
status = cipher.exchangeKeys();
retCode = cipher.getRetCode();
LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============");
status = cipherClient.exchangeKeys();
retCode = cipherClient.getRetCode();
LOGGER.info(String.format("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: %s ",
"============", status));
if (status != FLClientStatus.SUCCESS) {
return status;
}
// round 1
try {
status = cipher.shareSecrets();
retCode = cipher.getRetCode();
LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "=============");
} catch (Exception e) {
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
status = FLClientStatus.FAILED;
}
status = cipherClient.shareSecrets();
retCode = cipherClient.getRetCode();
LOGGER.info(String.format("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: %s ",
"=============", status));
if (status != FLClientStatus.SUCCESS) {
return status;
}
// round2
try {
featureMask = cipher.doubleMaskingWeight();
retCode = cipher.getRetCode();
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
} catch (Exception e) {
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
status = FLClientStatus.FAILED;
featureMask = cipherClient.doubleMaskingWeight();
if (featureMask == null || featureMask.length <= 0) {
LOGGER.severe(Common.addTag("[Encrypt] the returned featureMask from cipherClient.doubleMaskingWeight" +
" is null, please check!"));
return FLClientStatus.FAILED;
}
retCode = cipherClient.getRetCode();
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
return status;
}
/**
* Add the pairwise mask and individual mask to model weights.
*
* @param builder the FlatBufferBuilder object used for serialization model weights.
* @param trainDataSize trainDataSize tne size of train data set.
* @return the serialized model weights after adding masks.
*/
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
if (featureMask == null || featureMask.length == 0) {
LOGGER.severe("[Encrypt] feature mask is null, please check");
return new int[0];
}
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length));
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
@ -142,6 +194,9 @@ public class SecureProtocol {
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
} else {
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
throw new IllegalArgumentException();
}
int featureSize = encryptFeatureName.size();
int[] featuresMap = new int[featureSize];
@ -149,9 +204,13 @@ public class SecureProtocol {
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
LOGGER.info("[Encrypt] feature name: " + key + " feature size: " + data.length);
LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length));
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
if (maskIndex >= featureMask.length) {
LOGGER.severe("[Encrypt] the maskIndex is out of range for array featureMask, please check");
return new int[0];
}
float maskData = rawData * trainDataSize + featureMask[maskIndex];
maskIndex += 1;
data[j] = maskData;
@ -164,17 +223,23 @@ public class SecureProtocol {
return featuresMap;
}
/**
* Reconstruct the secrets used for unmasking model weights.
*
* @return current status code in client.
*/
public FLClientStatus pwUnmasking() {
status = cipher.reconstructSecrets(); // round3
retCode = cipher.getRetCode();
LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "=============");
status = cipherClient.reconstructSecrets(); // round3
retCode = cipherClient.getRetCode();
LOGGER.info(String.format("[Encrypt] =============GetClientList+SendReconstructSecret: %s =============",
status));
return status;
}
private static float calculateErf(double x) {
double result = 0;
private static float calculateErf(double erfInput) {
double result = 0d;
int segmentNum = 10000;
double deltaX = x / segmentNum;
double deltaX = erfInput / segmentNum;
result += 1;
for (int i = 1; i < segmentNum; i++) {
result += 2 * Math.exp(-Math.pow(deltaX * i, 2));
@ -183,33 +248,36 @@ public class SecureProtocol {
return (float) (result * deltaX / Math.pow(Math.PI, 0.5));
}
private static double calculatePhi(double t) {
return 0.5 * (1.0 + calculateErf((t / Math.sqrt(2.0))));
private static double calculatePhi(double phiInput) {
return 0.5 * (1.0 + calculateErf((phiInput / Math.sqrt(2.0))));
}
private static double calculateBPositive(double eps, double s) {
return calculatePhi(Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
private static double calculateBPositive(double eps, double calInput) {
return calculatePhi(Math.sqrt(eps * calInput)) -
Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
}
private static double calculateBNegative(double eps, double s) {
return calculatePhi(-Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
private static double calculateBNegative(double eps, double calInput) {
return calculatePhi(-Math.sqrt(eps * calInput)) -
Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
}
private static double calculateSPositive(double eps, double targetDelta, double sInf, double sSup) {
double deltaSup = calculateBPositive(eps, sSup);
private static double calculateSPositive(double eps, double targetDelta, double initSInf, double initSSup) {
double deltaSup = calculateBPositive(eps, initSSup);
double sInf = initSInf;
double sSup = initSSup;
while (deltaSup <= targetDelta) {
sInf = sSup;
sSup = 2 * sInf;
deltaSup = calculateBPositive(eps, sSup);
}
double sMid = sInf + (sSup - sInf) / 2.0;
int iterMax = 1000;
int iters = 0;
while (true) {
double b = calculateBPositive(eps, sMid);
if (b <= targetDelta) {
if (targetDelta - b <= deltaError) {
double bPositive = calculateBPositive(eps, sMid);
if (bPositive <= targetDelta) {
if (targetDelta - bPositive <= deltaError) {
break;
} else {
sInf = sMid;
@ -226,8 +294,10 @@ public class SecureProtocol {
return sMid;
}
private static double calculateSNegative(double eps, double targetDelta, double sInf, double sSup) {
double deltaSup = calculateBNegative(eps, sSup);
private static double calculateSNegative(double eps, double targetDelta, double initSInf, double initSSup) {
double deltaSup = calculateBNegative(eps, initSSup);
double sInf = initSInf;
double sSup = initSSup;
while (deltaSup > targetDelta) {
sInf = sSup;
sSup = 2 * sInf;
@ -238,9 +308,9 @@ public class SecureProtocol {
int iterMax = 1000;
int iters = 0;
while (true) {
double b = calculateBNegative(eps, sMid);
if (b <= targetDelta) {
if (targetDelta - b <= deltaError) {
double bNegative = calculateBNegative(eps, sMid);
if (bNegative <= targetDelta) {
if (targetDelta - bNegative <= deltaError) {
break;
} else {
sSup = sMid;
@ -259,17 +329,26 @@ public class SecureProtocol {
private static double calculateSigma(double clipNorm, double eps, double targetDelta) {
double deltaZero = calculateBPositive(eps, 0);
double alpha = 1;
double alpha = 1d;
if (targetDelta > deltaZero) {
double s = calculateSPositive(eps, targetDelta, 0, 1);
alpha = Math.sqrt(1.0 + s / 2.0) - Math.sqrt(s / 2.0);
double sPositive = calculateSPositive(eps, targetDelta, 0, 1);
alpha = Math.sqrt(1.0 + sPositive / 2.0) - Math.sqrt(sPositive / 2.0);
} else if (targetDelta < deltaZero) {
double s = calculateSNegative(eps, targetDelta, 0, 1);
alpha = Math.sqrt(1.0 + s / 2.0) + Math.sqrt(s / 2.0);
double sNegative = calculateSNegative(eps, targetDelta, 0, 1);
alpha = Math.sqrt(1.0 + sNegative / 2.0) + Math.sqrt(sNegative / 2.0);
} else {
LOGGER.info(Common.addTag("[Encrypt] targetDelta = deltaZero"));
}
return alpha * clipNorm / Math.sqrt(2.0 * eps);
}
/**
* Add differential privacy mask to model weights.
*
* @param builder the FlatBufferBuilder object used for serialization model weights.
* @param trainDataSize tne size of train data set.
* @return the serialized model weights after adding masks.
*/
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
// get feature map
Map<String, float[]> map = new HashMap<String, float[]>();
@ -279,6 +358,9 @@ public class SecureProtocol {
} else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
} else {
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
throw new IllegalArgumentException();
}
Map<String, float[]> mapBeforeTrain = modelMap;
int featureSize = encryptFeatureName.size();
@ -286,19 +368,18 @@ public class SecureProtocol {
double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta);
LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "============="));
// prepare gaussian noise
SecureRandom random = new SecureRandom();
int randomInt = random.nextInt();
Random r = new Random(randomInt);
// calculate l2-norm of all layers' update array
double updateL2Norm = 0;
double updateL2Norm = 0d;
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
float[] dataBeforeTrain = mapBeforeTrain.get(key);
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
if (j >= dataBeforeTrain.length) {
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
return new int[0];
}
float rawDataBeforeTrain = dataBeforeTrain[j];
float updateData = rawData - rawDataBeforeTrain;
updateL2Norm += updateData * updateData;
@ -311,11 +392,26 @@ public class SecureProtocol {
int[] featuresMap = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
if (!map.containsKey(key)) {
LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!");
return new int[0];
}
float[] data = map.get(key);
float[] data2 = new float[data.length];
if (!mapBeforeTrain.containsKey(key)) {
LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!");
return new int[0];
}
float[] dataBeforeTrain = mapBeforeTrain.get(key);
// prepare gaussian noise
SecureRandom secureRandom = Common.getSecureRandom();
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
if (j >= dataBeforeTrain.length) {
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
return new int[0];
}
float rawDataBeforeTrain = dataBeforeTrain[j];
float updateData = rawData - rawDataBeforeTrain;
@ -323,7 +419,7 @@ public class SecureProtocol {
updateData *= clipFactor;
// add noise
double gaussianNoise = r.nextGaussian() * gaussianSigma;
double gaussianNoise = secureRandom.nextGaussian() * gaussianSigma;
updateData += gaussianNoise;
data2[j] = rawDataBeforeTrain + updateData;
data2[j] = data2[j] * trainDataSize;
@ -335,5 +431,4 @@ public class SecureProtocol {
}
return featuresMap;
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,8 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
/**
* The training mode of federated learning.
*
* @since 2021-06-30
*/
public enum ServerMod {
FEDERATED_LEARNING,
HYBRID_TRAINING

View File

@ -1,6 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@ -13,13 +12,20 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FLPlan;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestFLJob;
import mindspore.schema.ResponseCode;
@ -28,15 +34,28 @@ import mindspore.schema.ResponseFLJob;
import java.util.ArrayList;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
/**
* StartFLJob
*
* @since 2021-08-25
*/
public class StartFLJob {
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString());
private static volatile StartFLJob startFLJob;
static {
System.loadLibrary("mindspore-lite-jni");
}
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString());
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int featureSize;
private String nextRequestTime;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private StartFLJob() {
}
class RequestStartFLJobBuilder {
private RequestFLJob requestFLJob;
@ -51,21 +70,53 @@ public class StartFLJob {
builder = new FlatBufferBuilder();
}
/**
* set flName
*
* @param name String
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder flName(String name) {
if (name == null || name.isEmpty()) {
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <name> is null or empty, please check!"));
throw new IllegalArgumentException();
}
this.nameOffset = this.builder.createString(name);
return this;
}
/**
* set id
*
* @param id String
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder id(String id) {
if (id == null || id.isEmpty()) {
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <id> is null or empty, please check!"));
throw new IllegalArgumentException();
}
this.idOffset = this.builder.createString(id);
return this;
}
/**
* set time
*
* @param timestamp long
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder time(long timestamp) {
this.timestampOffset = builder.createString(String.valueOf(timestamp));
return this;
}
/**
* set dataSize
*
* @param dataSize int
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder dataSize(int dataSize) {
// temp code need confirm
this.dataSize = dataSize;
@ -73,11 +124,22 @@ public class StartFLJob {
return this;
}
/**
* set iteration
*
* @param iteration iteration
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
/**
* build protobuffer
*
* @return byte[] data
*/
public byte[] build() {
int root = RequestFLJob.createRequestFLJob(this.builder, this.nameOffset, this.idOffset, this.iteration,
this.dataSize, this.timestampOffset);
@ -86,20 +148,11 @@ public class StartFLJob {
}
}
private static volatile StartFLJob startFLJob;
private FLClientStatus status;
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int featureSize;
private String nextRequestTime;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private StartFLJob() {
}
/**
* getInstance of StartFLJob
*
* @return StartFLJob instance
*/
public static StartFLJob getInstance() {
StartFLJob localRef = startFLJob;
if (localRef == null) {
@ -117,6 +170,14 @@ public class StartFLJob {
return nextRequestTime;
}
/**
* get request start FLJob
*
* @param dataSize dataSize
* @param iteration iteration
* @param time time
* @return byte[] data
*/
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) {
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
return builder.flName(flParameter.getFlName())
@ -135,6 +196,7 @@ public class StartFLJob {
return encryptFeatureName;
}
private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) {
int fmCount = flJob.featureMapLength();
encryptFeatureName.clear();
@ -149,6 +211,10 @@ public class StartFLJob {
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature);
@ -160,19 +226,23 @@ public class StartFLJob {
} else {
continue;
}
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
"weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------"));
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
"model-----------------"));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
inferFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
albertFeatureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
@ -182,16 +252,22 @@ public class StartFLJob {
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
"weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
featureMaps);
if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED;
@ -206,11 +282,16 @@ public class StartFLJob {
encryptFeatureName.clear();
for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname();
featureMaps.add(feature);
featureSize += feature.dataLength();
encryptFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " +
feature.weightFullname() + ", weightLength: " + feature.dataLength()));
}
int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
@ -223,7 +304,22 @@ public class StartFLJob {
return FLClientStatus.SUCCESS;
}
/**
* response res
*
* @param flJob ResponseFLJob
* @return FLClientStatus
*/
public FLClientStatus doResponse(ResponseFLJob flJob) {
if (flJob == null) {
LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null"));
return FLClientStatus.FAILED;
}
FLPlan flPlanConfig = flJob.flPlanConfig();
if (flPlanConfig == null) {
LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode()));
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
@ -236,11 +332,12 @@ public class StartFLJob {
switch (retCode) {
case (ResponseCode.SUCCEED):
localFLParameter.setServerMod(flJob.flPlanConfig().serverMode());
localFLParameter.setServerMod(flPlanConfig.serverMode());
if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
status = parseResponseAlbert(flJob);
} else if (LENET.equals(flParameter.getFlName())) {
}
if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
status = parseResponseLenet(flJob);
}
@ -256,8 +353,4 @@ public class StartFLJob {
return FLClientStatus.FAILED;
}
}
public FLClientStatus getStatus() {
return this.status;
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,35 +13,49 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.ResponseGetModel;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.ResponseGetModel;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
/**
* SyncFLJob defines the APIs for federated learning task.
* API flJobRun() for starting federated learning on the device, the API modelInference() for inference on the
* device, and the API getModel() for obtaining the latest model on the cloud.
*
* @since 2021-06-30
*/
public class SyncFLJob {
private static final Logger LOGGER = Logger.getLogger(SyncFLJob.class.toString());
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
private Map<String, float[]> oldFeatureMap;
public SyncFLJob() {
}
/**
* Starts a federated learning task on the device.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus flJobRun() {
Common.setSecureRandom(new SecureRandom());
localFLParameter.setFlID(flParameter.getClientID());
FLLiteClient client = new FLLiteClient();
FLClientStatus curStatus;
@ -58,17 +72,14 @@ public class SyncFLJob {
if (trainDataSize <= 0) {
LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0"));
curStatus = FLClientStatus.FAILED;
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
client.getRetCode());
break;
}
client.setTrainDataSize(trainDataSize);
// startFLJob
curStatus = client.startFLJob();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.startFLJob();
}
curStatus = startFLJob(client);
if (curStatus == FLClientStatus.RESTART) {
restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue;
@ -100,11 +111,7 @@ public class SyncFLJob {
LOGGER.info(Common.addTag("[train] train succeed"));
// updateModel
curStatus = client.updateModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.updateModel();
}
curStatus = updateModel(client);
if (curStatus == FLClientStatus.RESTART) {
restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue;
@ -124,11 +131,7 @@ public class SyncFLJob {
}
// getModel
curStatus = client.getModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.getModel();
}
curStatus = getModel(client);
if (curStatus == FLClientStatus.RESTART) {
restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue;
@ -142,7 +145,8 @@ public class SyncFLJob {
// evaluate model after getting model from server
if (flParameter.getTestDataset().equals("null")) {
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the combine model"));
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the model after getting" +
" model from server"));
} else {
curStatus = client.evaluateModel();
if (curStatus == FLClientStatus.FAILED) {
@ -151,33 +155,68 @@ public class SyncFLJob {
}
LOGGER.info(Common.addTag("[evaluate] evaluate succeed"));
}
LOGGER.info(Common.addTag("========================================================the total response of " + client.getIteration() + ": " + curStatus + "======================================================================"));
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
LOGGER.info(Common.addTag("========================================================the total response of "
+ client.getIteration() + ": " + curStatus +
"======================================================================"));
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
client.getRetCode());
} while (client.getIteration() < client.getIterations());
client.finalize();
client.freeSession();
LOGGER.info(Common.addTag("flJobRun finish"));
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
return curStatus;
}
private FLClientStatus startFLJob(FLLiteClient client) {
FLClientStatus curStatus = client.startFLJob();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.startFLJob();
}
return curStatus;
}
private FLClientStatus updateModel(FLLiteClient client) {
FLClientStatus curStatus = client.updateModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.updateModel();
}
return curStatus;
}
private FLClientStatus getModel(FLLiteClient client) {
FLClientStatus curStatus = client.getModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.getModel();
}
return curStatus;
}
private void updateDpNormClip(FLLiteClient client) {
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
int currentIter = client.getIteration();
Map<String, float[]> fedFeatureMap = getFeatureMap();
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
if (fedWeightUpdateNorm == -1) {
LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " +
"-1, please check!"));
throw new IllegalArgumentException();
}
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
float newNormCLip = (float) client.dpNormClipFactor * fedWeightUpdateNorm;
float newNormCLip = (float) client.getDpNormClipFactor() * fedWeightUpdateNorm;
if (currentIter == 1) {
client.dpNormClipAdapt = newNormCLip;
client.setDpNormClipAdapt(newNormCLip);
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
} else {
if (newNormCLip < client.dpNormClipAdapt) {
client.dpNormClipAdapt = newNormCLip;
if (newNormCLip < client.getDpNormClipAdapt()) {
client.setDpNormClipAdapt(newNormCLip);
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
}
}
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.dpNormClipAdapt));
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.getDpNormClipAdapt()));
}
}
@ -190,11 +229,16 @@ public class SyncFLJob {
}
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
float updateL2Norm = 0;
float updateL2Norm = 0f;
for (String key : originalData.keySet()) {
float[] data = originalData.get(key);
float[] dataAfterUpdate = newData.get(key);
for (int j = 0; j < data.length; j++) {
if (j >= dataAfterUpdate.length) {
LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " +
"please check");
return -1;
}
float updateData = data[j] - dataAfterUpdate[j];
updateL2Norm += updateData * updateData;
}
@ -215,12 +259,22 @@ public class SyncFLJob {
return featureMap;
}
/**
* Starts an inference task on the device.
*
* @return the status code corresponding to the response message.
*/
public int[] modelInference() {
int[] labels = new int[0];
if (flParameter.getFlName().equals(ALBERT)) {
AlInferBert alInferBert = AlInferBert.getInstance();
LOGGER.info(Common.addTag("===========model inference============="));
labels = alInferBert.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile());
labels = alInferBert.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset(),
flParameter.getVocabFile(), flParameter.getIdsFile());
if (labels == null || labels.length == 0) {
LOGGER.severe("[model inference] the returned label from adInferBert.inferModel() is null, please " +
"check");
}
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
SessionUtil.free(alInferBert.getTrainSession());
LOGGER.info(Common.addTag("[model inference] inference finish"));
@ -228,49 +282,62 @@ public class SyncFLJob {
TrainLenet trainLenet = TrainLenet.getInstance();
LOGGER.info(Common.addTag("===========model inference============="));
labels = trainLenet.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset().split(",")[0]);
if (labels == null || labels.length == 0) {
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
}
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
SessionUtil.free(trainLenet.getTrainSession());
LOGGER.info(Common.addTag("[model inference] inference finish"));
}
if (labels.length == 0) {
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
}
return labels;
}
/**
* Obtains the latest model on the cloud.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus getModel() {
Common.setSecureRandom(Common.getFastSecureRandom());
int tag = 0;
FLClientStatus status = FLClientStatus.SUCCESS;
FLClientStatus status;
try {
if (flParameter.getFlName().equals(ALBERT)) {
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
flParameter.getTrainModelPath() + " Create Train Session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the " +
"return is -1"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " +
flParameter.getInferModelPath() + " Create inference Session============="));
AlInferBert alInferBert = AlInferBert.getInstance();
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) {
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
flParameter.getTrainModelPath() + " Create Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
}
if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
"is -1"));
return FLClientStatus.FAILED;
}
FLCommunication flCommunication = FLCommunication.getInstance();
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(message)) {
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
"again"));
status = FLClientStatus.WAIT;
return status;
}
@ -279,8 +346,8 @@ public class SyncFLJob {
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
status = getModelBuf.doResponse(responseDataBuf);
LOGGER.info(Common.addTag("[getModel] success!"));
} catch (Exception e) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage()));
} catch (Exception ex) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + ex.getMessage()));
status = FLClientStatus.FAILED;
}
if (flParameter.getFlName().equals(ALBERT)) {
@ -299,19 +366,16 @@ public class SyncFLJob {
}
private void waitSomeTime() {
if (flParameter.getSleepTime() != 0)
if (flParameter.getSleepTime() != 0) {
Common.sleep(flParameter.getSleepTime());
else
} else {
Common.sleep(SLEEP_TIME);
}
}
private void waitNextReqTime(String nextReqTime) {
if (flParameter.isTimer()) {
long waitTime = Common.getWaitTime(nextReqTime);
Common.sleep(waitTime);
} else {
waitSomeTime();
}
long waitTime = Common.getWaitTime(nextReqTime);
Common.sleep(waitTime);
}
private void restart(String tag, String nextReqTime, int iteration, int retcode) {
@ -322,7 +386,8 @@ public class SyncFLJob {
private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) {
LOGGER.info(Common.addTag(tag + " failed"));
LOGGER.info(Common.addTag("========================================================the total response of " + iteration + ": " + curStatus + "======================================================================"));
LOGGER.info(Common.addTag("=========================================the total response of " +
iteration + ": " + curStatus + "========================================="));
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
}
@ -334,53 +399,27 @@ public class SyncFLJob {
String flName = args[4];
String trainModelPath = args[5];
String inferModelPath = args[6];
String clientID = args[7];
String ip = args[8];
boolean useSSL = Boolean.parseBoolean(args[9]);
int port = Integer.parseInt(args[10]);
int timeWindow = Integer.parseInt(args[11]);
boolean useElb = Boolean.parseBoolean(args[12]);
int serverNum = Integer.parseInt(args[13]);
boolean useHttps = Boolean.parseBoolean(args[14]);
String certPath = args[15];
String task = args[16];
boolean useSSL = Boolean.parseBoolean(args[7]);
String domainName = args[8];
boolean useElb = Boolean.parseBoolean(args[9]);
int serverNum = Integer.parseInt(args[10]);
String certPath = args[11];
String task = args[12];
FLParameter flParameter = FLParameter.getInstance();
LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset));
LOGGER.info(Common.addTag("[args] vocabFile: " + vocabFile));
LOGGER.info(Common.addTag("[args] idsFile: " + idsFile));
LOGGER.info(Common.addTag("[args] testDataset: " + testDataset));
LOGGER.info(Common.addTag("[args] flName: " + flName));
LOGGER.info(Common.addTag("[args] trainModelPath: " + trainModelPath));
LOGGER.info(Common.addTag("[args] inferModelPath: " + inferModelPath));
LOGGER.info(Common.addTag("[args] clientID: " + clientID));
LOGGER.info(Common.addTag("[args] ip: " + ip));
LOGGER.info(Common.addTag("[args] useSSL: " + useSSL));
LOGGER.info(Common.addTag("[args] port: " + port));
LOGGER.info(Common.addTag("[args] timeWindow: " + timeWindow));
LOGGER.info(Common.addTag("[args] useElb: " + useElb));
LOGGER.info(Common.addTag("[args] serverNum: " + serverNum));
LOGGER.info(Common.addTag("[args] useHttps: " + useHttps));
LOGGER.info(Common.addTag("[args] certPath: " + certPath));
LOGGER.info(Common.addTag("[args] task: " + task));
flParameter.setClientID(clientID);
SyncFLJob syncFLJob = new SyncFLJob();
if (task.equals("train")) {
flParameter.setUseHttps(useHttps);
if (useSSL) {
flParameter.setCertPath(certPath);
}
flParameter.setHostName(ip);
flParameter.setTrainDataset(trainDataset);
flParameter.setFlName(flName);
flParameter.setTrainModelPath(trainModelPath);
flParameter.setTestDataset(testDataset);
flParameter.setInferModelPath(inferModelPath);
flParameter.setIp(ip);
flParameter.setUseSSL(useSSL);
flParameter.setPort(port);
flParameter.setTimeWindow(timeWindow);
flParameter.setDomainName(domainName);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
if (ALBERT.equals(flName)) {
@ -398,17 +437,14 @@ public class SyncFLJob {
}
syncFLJob.modelInference();
} else if (task.equals("getModel")) {
flParameter.setUseHttps(useHttps);
if (useSSL) {
flParameter.setCertPath(certPath);
}
flParameter.setHostName(ip);
flParameter.setFlName(flName);
flParameter.setTrainModelPath(trainModelPath);
flParameter.setInferModelPath(inferModelPath);
flParameter.setIp(ip);
flParameter.setUseSSL(useSSL);
flParameter.setPort(port);
flParameter.setDomainName(domainName);
flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum);
syncFLJob.getModel();

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,10 +16,15 @@
package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap;
import mindspore.schema.RequestUpdateModel;
import mindspore.schema.ResponseCode;
@ -31,145 +36,31 @@ import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
/**
* Define the serialization method, handle the response message returned from server for updateModel request.
*
* @since 2021-06-30
*/
public class UpdateModel {
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
private static volatile UpdateModel updateModel;
static {
System.loadLibrary("mindspore-lite-jni");
}
class RequestUpdateModelBuilder {
private RequestUpdateModel requestUM;
private FlatBufferBuilder builder;
private int fmOffset = 0;
private int nameOffset = 0;
private int idOffset = 0;
private int timestampOffset = 0;
private int iteration = 0;
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
public RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
builder = new FlatBufferBuilder();
this.encryptLevel = encryptLevel;
}
public RequestUpdateModelBuilder flName(String name) {
this.nameOffset = this.builder.createString(name);
return this;
}
public RequestUpdateModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timestampOffset = builder.createString(String.valueOf(time));
return this;
}
public RequestUpdateModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
public RequestUpdateModelBuilder id(String id) {
this.idOffset = this.builder.createString(id);
return this;
}
public RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
switch (encryptLevel) {
case PW_ENCRYPT:
try {
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is null, please check");
throw new RuntimeException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
return this;
} catch (Exception e) {
LOGGER.severe("[Encrypt] catch error in maskModel: " + e.getMessage());
throw new RuntimeException();
}
case DP_ENCRYPT:
try {
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is null, please check");
throw new RuntimeException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
return this;
} catch (Exception e) {
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
throw new RuntimeException();
}
case NOT_ENCRYPT:
default:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
}
int featureSize = encryptFeatureName.size();
int[] fmOffsets = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature size: " + data.length));
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
data[j] = data[j] * trainDataSize;
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
fmOffsets[i] = featureMap;
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
return this;
}
}
public byte[] build() {
RequestUpdateModel.startRequestUpdateModel(this.builder);
RequestUpdateModel.addFlName(builder, nameOffset);
RequestUpdateModel.addFlId(this.builder, idOffset);
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
RequestUpdateModel.addIteration(builder, this.iteration);
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
int root = RequestUpdateModel.endRequestUpdateModel(builder);
builder.finish(root);
return builder.sizedByteArray();
}
}
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private String nextRequestTime;
private FLClientStatus status;
private static volatile UpdateModel updateModel;
private UpdateModel() {
}
/**
* Get the singleton object of the class UpdateModel.
*
* @return the singleton object of the class UpdateModel.
*/
public static UpdateModel getInstance() {
UpdateModel localRef = updateModel;
if (localRef == null) {
@ -183,25 +74,35 @@ public class UpdateModel {
return localRef;
}
public String getNextRequestTime() {
return nextRequestTime;
}
public FLClientStatus getStatus() {
return status;
}
/**
* Get a flatBuffer builder of RequestUpdateModel.
*
* @param iteration current iteration of federated learning task.
* @param secureProtocol the object that defines encryption and decryption methods.
* @param trainDataSize the size of train date set.
* @return the flatBuffer builder of RequestUpdateModel in byte[] format.
*/
public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) {
RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel());
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID()).featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID())
.featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
}
/**
* Handle the response message returned from server.
*
* @param response the response message returned from server.
* @return the status code corresponding to the response message.
*/
public FLClientStatus doResponse(ResponseUpdateModel response) {
LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================"));
LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode()));
LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason()));
LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime()));
nextRequestTime = response.nextReqTime();
switch (response.retcode()) {
case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[updateModel] updateModel success"));
@ -213,8 +114,165 @@ public class UpdateModel {
LOGGER.warning(Common.addTag("[updateModel] catch RequestError or SystemError"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[updateModel]the return <retcode> from server is invalid: " + response.retcode()));
LOGGER.severe(Common.addTag("[updateModel]the return <retCode> from server is invalid: " +
response.retcode()));
return FLClientStatus.FAILED;
}
}
class RequestUpdateModelBuilder {
private RequestUpdateModel requestUM;
private FlatBufferBuilder builder;
private int fmOffset = 0;
private int nameOffset = 0;
private int idOffset = 0;
private int timestampOffset = 0;
private int iteration = 0;
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
private RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
builder = new FlatBufferBuilder();
this.encryptLevel = encryptLevel;
}
/**
* Serialize the element flName in RequestUpdateModel.
*
* @param name the model name.
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder flName(String name) {
if (name == null || name.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the parameter of <name> is null or empty, please check!"));
throw new IllegalArgumentException();
}
this.nameOffset = this.builder.createString(name);
return this;
}
/**
* Serialize the element timestamp in RequestUpdateModel.
*
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timestampOffset = builder.createString(String.valueOf(time));
return this;
}
/**
* Serialize the element iteration in RequestUpdateModel.
*
* @param iteration current iteration of federated learning task.
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
/**
* Serialize the element fl_id in RequestUpdateModel.
*
* @param id a number that uniquely identifies a client.
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder id(String id) {
if (id == null || id.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the parameter of <id> is null or empty, please check!"));
throw new IllegalArgumentException();
}
this.idOffset = this.builder.createString(id);
return this;
}
private RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
switch (encryptLevel) {
case PW_ENCRYPT:
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " +
"null, please check");
throw new IllegalArgumentException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
return this;
case DP_ENCRYPT:
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " +
"null, please check");
throw new IllegalArgumentException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
return this;
case NOT_ENCRYPT:
default:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
flParameter.getFlName()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
".convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
flParameter.getFlName()));
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
".convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
} else {
LOGGER.severe(Common.addTag("[updateModel] the flName is not valid"));
throw new IllegalArgumentException();
}
int featureSize = encryptFeatureName.size();
int[] fmOffsets = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
"size: " + data.length));
for (int j = 0; j < data.length; j++) {
data[j] = data[j] * trainDataSize;
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
fmOffsets[i] = featureMap;
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
return this;
}
}
/**
* Create a flatBuffer builder of RequestUpdateModel.
*
* @return the flatBuffer builder of RequestUpdateModel in byte[] format.
*/
private byte[] build() {
RequestUpdateModel.startRequestUpdateModel(this.builder);
RequestUpdateModel.addFlName(builder, nameOffset);
RequestUpdateModel.addFlId(this.builder, idOffset);
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
RequestUpdateModel.addIteration(builder, this.iteration);
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
int root = RequestUpdateModel.endRequestUpdateModel(builder);
builder.finish(root);
return builder.sizedByteArray();
}
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,21 +16,33 @@
package com.mindspore.flclient.cipher;
import static com.mindspore.flclient.LocalFLParameter.I_VEC_LEN;
import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
import com.mindspore.flclient.Common;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.io.UnsupportedEncodingException;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.logging.Logger;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
/**
* Define encryption and decryption methods.
*
* @since 2021-06-30
*/
public class AESEncrypt {
private static final Logger LOGGER = Logger.getLogger(AESEncrypt.class.toString());
/**
* 128, 192 or 256
*/
private static final int KEY_SIZE = 256;
private static final int I_VEC_LEN = 16;
/**
* encrypt/decrypt algorithm name
@ -43,62 +55,145 @@ public class AESEncrypt {
private static final String CIPHER_MODE_CTR = "AES/CTR/NoPadding";
private static final String CIPHER_MODE_CBC = "AES/CBC/PKCS5PADDING";
private String CIPHER_MODE;
private String cipherMod;
private static final int RANDOM_LEN = KEY_SIZE / 8;
private String iVecS = "1111111111111111";
private byte[] iVec = iVecS.getBytes("utf-8");
public AESEncrypt(byte[] key, byte[] iVecIn, String mode) throws UnsupportedEncodingException {
/**
* Defining a Constructor of the class AESEncrypt.
*
* @param key the Key.
* @param mode the encryption Mode.
*/
public AESEncrypt(byte[] key, String mode) {
if (key == null) {
LOGGER.severe(Common.addTag("Key is null"));
return;
}
if (key.length != KEY_SIZE / 8) {
if (key.length != KEY_LEN) {
LOGGER.severe(Common.addTag("the length of key is not correct"));
return;
}
if (mode.contains("CBC")) {
CIPHER_MODE = CIPHER_MODE_CBC;
cipherMod = CIPHER_MODE_CBC;
} else if (mode.contains("CTR")) {
CIPHER_MODE = CIPHER_MODE_CTR;
cipherMod = CIPHER_MODE_CTR;
} else {
return;
}
if (iVecIn == null || iVecIn.length != I_VEC_LEN) {
return;
}
/**
* Defining the CBC encryption Mode.
*
* @param key the Key.
* @param data the data to be encrypted.
* @return the data to be encrypted.
*/
public byte[] encrypt(byte[] key, byte[] data) {
if (key == null) {
LOGGER.severe(Common.addTag("Key is null"));
return new byte[0];
}
if (data == null) {
LOGGER.severe(Common.addTag("data is null"));
return new byte[0];
}
try {
byte[] iVec = new byte[I_VEC_LEN];
SecureRandom secureRandom = Common.getSecureRandom();
secureRandom.nextBytes(iVec);
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(cipherMod);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
byte[] encrypted = cipher.doFinal(data);
byte[] encryptedAddIv = new byte[encrypted.length + iVec.length];
System.arraycopy(iVec, 0, encryptedAddIv, 0, iVec.length);
System.arraycopy(encrypted, 0, encryptedAddIv, iVec.length, encrypted.length);
return encryptedAddIv;
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException or " +
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
return new byte[0];
}
iVec = iVecIn;
}
public byte[] encrypt(byte[] key, byte[] data) throws Exception {
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
byte[] encrypted = cipher.doFinal(data);
String encryptResultStr = BaseUtil.byte2HexString(encrypted);
return encrypted;
/**
* Defining the CTR encryption Mode.
*
* @param key the Key.
* @param data the data to be encrypted.
* @param iVec the IV value.
* @return the data to be encrypted.
*/
public byte[] encryptCTR(byte[] key, byte[] data, byte[] iVec) {
if (key == null) {
LOGGER.severe(Common.addTag("Key is null"));
return new byte[0];
}
if (data == null) {
LOGGER.severe(Common.addTag("data is null"));
return new byte[0];
}
if (iVec == null || iVec.length != I_VEC_LEN) {
LOGGER.severe(Common.addTag("iVec is null or the length of iVec is not valid, it should be " + "I_VEC_LEN"
));
return new byte[0];
}
try {
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(cipherMod);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
return cipher.doFinal(data);
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
LOGGER.severe(Common.addTag("[encryptCTR] catch NoSuchAlgorithmException or " +
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
return new byte[0];
}
}
public byte[] encryptCTR(byte[] key, byte[] data) throws Exception {
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
byte[] encrypted = cipher.doFinal(data);
return encrypted;
/**
* Defining the decrypt method.
*
* @param key the Key.
* @param encryptDataAddIv the data to be decrypted.
* @return the data to be decrypted.
*/
public byte[] decrypt(byte[] key, byte[] encryptDataAddIv) {
if (key == null) {
LOGGER.severe(Common.addTag("Key is null"));
return new byte[0];
}
if (encryptDataAddIv == null) {
LOGGER.severe(Common.addTag("encryptDataAddIv is null"));
return new byte[0];
}
if (encryptDataAddIv.length <= I_VEC_LEN) {
LOGGER.severe(Common.addTag("the length of encryptDataAddIv is not valid: " + encryptDataAddIv.length +
", it should be > " + I_VEC_LEN));
return new byte[0];
}
try {
byte[] iVec = Arrays.copyOfRange(encryptDataAddIv, 0, I_VEC_LEN);
byte[] encryptData = Arrays.copyOfRange(encryptDataAddIv, I_VEC_LEN, encryptDataAddIv.length);
SecretKeySpec sKeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(cipherMod);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.DECRYPT_MODE, sKeySpec, iv);
return cipher.doFinal(encryptData);
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException or " +
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
return new byte[0];
}
}
public byte[] decrypt(byte[] key, byte[] encryptData) throws Exception {
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.DECRYPT_MODE, skeySpec, iv);
byte[] origin = cipher.doFinal(encryptData);
return origin;
}
}

View File

@ -1,18 +1,19 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* <p>
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
*
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import java.io.UnsupportedEncodingException;
@ -21,14 +22,23 @@ import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
/**
* Define conversion methods between basic data types.
*
* @since 2021-06-30
*/
public class BaseUtil {
private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
public BaseUtil() {
}
private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B',
'C', 'D', 'E', 'F'};
/**
* Convert byte[] to String in hexadecimal format.
*
* @param bytes the byte[] object.
* @return the String object converted from byte[].
*/
public static String byte2HexString(byte[] bytes) {
if (null == bytes) {
if (bytes == null) {
return null;
} else if (bytes.length == 0) {
return "";
@ -36,14 +46,20 @@ public class BaseUtil {
char[] chars = new char[bytes.length * 2];
for (int i = 0; i < bytes.length; ++i) {
int b = bytes[i];
chars[i * 2] = HEX_DIGITS[(b & 240) >> 4];
chars[i * 2 + 1] = HEX_DIGITS[b & 15];
int byteNum = bytes[i];
chars[i * 2] = HEX_DIGITS[(byteNum & 240) >> 4];
chars[i * 2 + 1] = HEX_DIGITS[byteNum & 15];
}
return new String(chars);
}
}
/**
* Convert String in hexadecimal format to byte[].
*
* @param str the String object.
* @return the byte[] converted from String object.
*/
public static byte[] hexString2ByteArray(String str) {
int length = str.length() / 2;
byte[] bytes = new byte[length];
@ -58,8 +74,13 @@ public class BaseUtil {
return bytes;
}
/**
* Convert byte[] to BigInteger.
*
* @param bytes the byte[] object.
* @return the BigInteger object converted from byte[].
*/
public static BigInteger byteArray2BigInteger(byte[] bytes) {
BigInteger bigInteger = BigInteger.ZERO;
for (int i = 0; i < bytes.length; ++i) {
int intI = bytes[i];
@ -72,6 +93,13 @@ public class BaseUtil {
return bigInteger;
}
/**
* Convert String to BigInteger.
*
* @param str the String object.
* @return the BigInteger object converted from String object.
* @throws UnsupportedEncodingException if the encoding is not supported.
*/
public static BigInteger string2BigInteger(String str) throws UnsupportedEncodingException {
StringBuilder res = new StringBuilder();
byte[] bytes = String.valueOf(str).getBytes("UTF-8");
@ -83,14 +111,20 @@ public class BaseUtil {
return bigInteger;
}
public static String bigInteger2String(BigInteger bigInteger) throws UnsupportedEncodingException {
/**
* Convert BigInteger to String.
*
* @param bigInteger the BigInteger object.
* @return the String object converted from BigInteger.
*/
public static String bigInteger2String(BigInteger bigInteger) {
StringBuilder res = new StringBuilder();
List<Integer> lists = new ArrayList<>();
BigInteger bi = bigInteger;
BigInteger DIV = BigInteger.valueOf(256);
BigInteger div = BigInteger.valueOf(256);
while (bi.compareTo(BigInteger.ZERO) > 0) {
lists.add(bi.mod(DIV).intValue());
bi = bi.divide(DIV);
lists.add(bi.mod(div).intValue());
bi = bi.divide(div);
}
for (int i = lists.size() - 1; i >= 0; --i) {
res.append((char) (int) (lists.get(i)));
@ -98,13 +132,19 @@ public class BaseUtil {
return res.toString();
}
public static byte[] bigInteger2byteArray(BigInteger bigInteger) throws UnsupportedEncodingException {
/**
* Convert BigInteger to byte[].
*
* @param bigInteger the BigInteger object.
* @return the byte[] object converted from BigInteger.
*/
public static byte[] bigInteger2byteArray(BigInteger bigInteger) {
List<Integer> lists = new ArrayList<>();
BigInteger bi = bigInteger;
BigInteger DIV = BigInteger.valueOf(256);
BigInteger div = BigInteger.valueOf(256);
while (bi.compareTo(BigInteger.ZERO) > 0) {
lists.add(bi.mod(DIV).intValue());
bi = bi.divide(DIV);
lists.add(bi.mod(div).intValue());
bi = bi.divide(div);
}
byte[] res = new byte[lists.size()];
for (int i = lists.size() - 1; i >= 0; --i) {
@ -113,13 +153,19 @@ public class BaseUtil {
return res;
}
/**
* Convert Integer to byte[].
*
* @param num the Integer object.
* @return the byte[] object converted from Integer.
*/
public static byte[] integer2byteArray(Integer num) {
List<Integer> lists = new ArrayList<>();
Integer bi = num;
Integer DIV = 256;
Integer div = 256;
while (bi > 0) {
lists.add(bi % DIV);
bi = bi / DIV;
lists.add(bi % div);
bi = bi / div;
}
byte[] res = new byte[lists.size()];
for (int i = lists.size() - 1; i >= 0; --i) {
@ -128,8 +174,13 @@ public class BaseUtil {
return res;
}
/**
* Convert byte[] to Integer.
*
* @param bytes the byte[] object.
* @return the Integer object converted from byte[].
*/
public static Integer byteArray2Integer(byte[] bytes) {
Integer num = 0;
for (int i = 0; i < bytes.length; ++i) {
int intI = bytes[i];

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,9 +14,13 @@
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLCommunication;
@ -25,23 +29,27 @@ import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import com.mindspore.flclient.cipher.struct.EncryptShare;
import com.mindspore.flclient.cipher.struct.NewArray;
import mindspore.schema.GetClientList;
import mindspore.schema.ResponseCode;
import mindspore.schema.ReturnClientList;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
/**
* Define the serialization method, handle the response message returned from server for GetClientList request.
*
* @since 2021-06-30
*/
public class ClientListReq {
private static final Logger LOGGER = Logger.getLogger(ClientListReq.class.toString());
private FLCommunication flCommunication;
private String nextRequestTime;
private FLParameter flParameter = FLParameter.getInstance();
@ -64,34 +72,63 @@ public class ClientListReq {
return retCode;
}
public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
/**
* Send serialized request message of GetClientList to server.
*
* @param iteration current iteration of federated learning task.
* @param u3ClientList list of clients successfully requested in UpdateModel round.
* @param decryptSecretsList list to store to decrypted secret fragments.
* @param returnShareList List of returned secret fragments from server.
* @param cuvKeys Keys used to decrypt secret fragments.
* @return the status code corresponding to the response message.
*/
public FLClientStatus getClientList(int iteration, List<String> u3ClientList,
List<DecryptShareSecrets> decryptSecretsList,
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
FlatBufferBuilder builder = new FlatBufferBuilder();
int id = builder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = builder.createString(dateTime);
int clientListRoot = GetClientList.createGetClientList(builder, id, iteration, time);
builder.finish(clientListRoot);
byte[] msg = builder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
try {
byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[getClientList] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getClientList] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
LOGGER.info(Common.addTag("getClientList responseData size: " + responseData.length));
ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer);
FLClientStatus status = judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);
return status;
} catch (Exception e) {
e.printStackTrace();
return judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);
} catch (IOException ex) {
LOGGER.severe(Common.addTag("[getClientList] unsolved error code in getClientList: catch IOException: " +
ex.getMessage()));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED;
}
}
public FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
/**
* Analyze the serialization message returned from server and perform corresponding processing.
*
* @param bufData Serialized message returned from server.
* @param u3ClientList list of clients successfully requested in UpdateModel round.
* @param decryptSecretsList list to store decrypted secret fragments.
* @param returnShareList List of returned secret fragments from server.
* @param cuvKeys Keys used to decrypt secret fragments.
* @return the status code corresponding to the response message.
*/
private FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList,
List<DecryptShareSecrets> decryptSecretsList,
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -109,18 +146,15 @@ public class ClientListReq {
String curFlId = bufData.clients(i);
u3ClientList.add(curFlId);
}
try {
decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys);
} catch (Exception e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
status = decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys);
return status;
case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetClientList again!"));
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
"GetClientList again!"));
return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob again"));
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob" +
" again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
@ -128,36 +162,66 @@ public class ClientListReq {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is invalid: " + retCode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is " +
"invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
public void decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) throws Exception {
private FLClientStatus decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList,
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
decryptSecretsList.clear();
int size = returnShareList.size();
if (size <= 0) {
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <returnShareList> is null"));
return FLClientStatus.FAILED;
}
if (cuvKeys.isEmpty()) {
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <cuvKeys> is null"));
return FLClientStatus.FAILED;
}
for (int i = 0; i < size; i++) {
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
EncryptShare encryptShare = returnShareList.get(i);
String vFlID = encryptShare.getFlID();
byte[] share = encryptShare.getShare().getArray();
byte[] iVecIn = new byte[IVEC_LEN];
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), iVecIn, "CBC");
if (!cuvKeys.containsKey(vFlID)) {
LOGGER.severe(Common.addTag("[PairWiseMask] the key <vFlID> is not in map <cuvKeys> "));
return FLClientStatus.FAILED;
}
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), "CBC");
byte[] decryptShare = aesEncrypt.decrypt(cuvKeys.get(vFlID), share);
if (decryptShare == null || decryptShare.length == 0) {
LOGGER.severe(Common.addTag("[decryptSecretShares] the return byte[] is null, please check!"));
return FLClientStatus.FAILED;
}
if (decryptShare.length < 4) {
LOGGER.severe(Common.addTag("[decryptSecretShares] the returned decryptShare is not valid: length is " +
"not right, please check!"));
return FLClientStatus.FAILED;
}
int sSize = (int) decryptShare[0];
int bSize = (int) decryptShare[1];
int sIndexLen = (int) decryptShare[2];
int bIndexLen = (int) decryptShare[3];
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen));
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen, 4 + sIndexLen + bIndexLen));
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen, 4 + sIndexLen + bIndexLen + sSize);
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize, 4 + sIndexLen + bIndexLen + sSize + bSize);
if (decryptShare.length < (4 + sIndexLen + bIndexLen + sSize + bSize)) {
LOGGER.severe(Common.addTag("[decryptSecretShares] the returned decryptShare is not valid: length is " +
"not right, please check!"));
return FLClientStatus.FAILED;
}
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen,
4 + sIndexLen + bIndexLen + sSize);
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize,
4 + sIndexLen + bIndexLen + sSize + bSize);
NewArray<byte[]> sSkVu = new NewArray<>();
sSkVu.setSize(sSize);
sSkVu.setArray(sSkUv);
NewArray bVu = new NewArray();
bVu.setSize(bSize);
bVu.setArray(bUv);
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen));
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen,
4 + sIndexLen + bIndexLen));
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
decryptShareSecrets.setFlID(vFlID);
decryptShareSecrets.setSSkVu(sSkVu);
decryptShareSecrets.setBVu(bVu);
@ -165,5 +229,6 @@ public class ClientListReq {
decryptShareSecrets.setIndexB(bIndex);
decryptSecretsList.add(decryptShareSecrets);
}
return FLClientStatus.SUCCESS;
}
}

View File

@ -1,6 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,6 +16,10 @@
package com.mindspore.flclient.cipher;
import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
import com.mindspore.flclient.Common;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator;
import org.bouncycastle.crypto.params.KeyParameter;
@ -25,39 +28,80 @@ import org.bouncycastle.math.ec.rfc7748.X25519;
import java.security.SecureRandom;
import java.util.logging.Logger;
/**
* Generate public-private key pairs and DH Keys.
*
* @since 2021-06-30
*/
public class KEYAgreement {
private static final Logger LOGGER = Logger.getLogger(KEYAgreement.class.toString());
private static final int PBKDF2_ITERATIONS = 10000;
private static final int SALT_SIZE = 32;
private static final int HASH_BIT_SIZE = 256;
private static final int KEY_LEN = X25519.SCALAR_SIZE;
private SecureRandom random = new SecureRandom();
private SecureRandom random = Common.getSecureRandom();
/**
* Generate private Key.
*
* @return the private Key.
*/
public byte[] generatePrivateKey() {
byte[] privateKey = new byte[KEY_LEN];
X25519.generatePrivateKey(random, privateKey);
return privateKey;
}
public byte[] generatePublicKey(byte[] privatekey) {
/**
* Use private Key to generate public Key.
*
* @param privateKey the private Key.
* @return the public Key.
*/
public byte[] generatePublicKey(byte[] privateKey) {
if (privateKey == null || privateKey.length == 0) {
LOGGER.severe(Common.addTag("privateKey is null"));
return new byte[0];
}
byte[] publicKey = new byte[KEY_LEN];
X25519.generatePublicKey(privatekey, 0, publicKey, 0);
X25519.generatePublicKey(privateKey, 0, publicKey, 0);
return publicKey;
}
public byte[] keyAgreement(byte[] privatekey, byte[] publicKey) {
/**
* Use private Key and public Key to generate DH Key.
*
* @param privateKey the private Key.
* @param publicKey the public Key.
* @return the DH Key.
*/
public byte[] keyAgreement(byte[] privateKey, byte[] publicKey) {
if (privateKey == null || privateKey.length == 0) {
LOGGER.severe(Common.addTag("privateKey is null"));
return new byte[0];
}
if (publicKey == null || publicKey.length == 0) {
LOGGER.severe(Common.addTag("publicKey is null"));
return new byte[0];
}
byte[] secret = new byte[KEY_LEN];
X25519.calculateAgreement(privatekey, 0, publicKey, 0, secret, 0);
X25519.calculateAgreement(privateKey, 0, publicKey, 0, secret, 0);
return secret;
}
/**
* Encrypt DH Key.
*
* @param password the DH Key.
* @param salt the salt value.
* @return encrypted DH Key.
*/
public byte[] getEncryptedPassword(byte[] password, byte[] salt) {
byte[] saltB = new byte[SALT_SIZE];
if (password == null || password.length == 0) {
LOGGER.severe(Common.addTag("password is null"));
return new byte[0];
}
PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest());
gen.init(password, saltB, PBKDF2_ITERATIONS);
byte[] dk = ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey();
return dk;
gen.init(password, salt, PBKDF2_ITERATIONS);
return ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey();
}
}

View File

@ -0,0 +1,115 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import com.mindspore.flclient.Common;
import java.security.SecureRandom;
import java.util.List;
import java.util.logging.Logger;
/**
* Define the basic method for generating pairwise mask and individual mask.
*
* @since 2021-06-30
*/
public class Masking {
private static final Logger LOGGER = Logger.getLogger(Masking.class.toString());
/**
* Random generate RNG algorithm name.
*/
private static final String RNG_ALGORITHM = "SHA1PRNG";
/**
* Generate individual mask.
*
* @param secret used to store individual mask.
* @return the int value, 0 indicates success, -1 indicates failed .
*/
public int getRandomBytes(byte[] secret) {
if (secret == null || secret.length == 0) {
LOGGER.severe(Common.addTag("[Masking] the input argument <secret> is null, please check!"));
return -1;
}
SecureRandom secureRandom = Common.getSecureRandom();
secureRandom.nextBytes(secret);
return 0;
}
/**
* Generate pairwise mask.
*
* @param noise used to store pairwise mask.
* @param length the length of individual mask.
* @param seed the seed for generate pairwise mask.
* @param iVec the IV value.
* @return the int value, 0 indicates success, -1 indicates failed .
*/
public int getMasking(List<Float> noise, int length, byte[] seed, byte[] iVec) {
if (length <= 0) {
LOGGER.severe(Common.addTag("[Masking] the input argument <length> is not valid: <= 0, please check!"));
return -1;
}
int intV = Integer.SIZE / 8;
int size = length * intV;
byte[] data = new byte[size];
for (int i = 0; i < size; i++) {
data[i] = 0;
}
AESEncrypt aesEncrypt = new AESEncrypt(seed, "CTR");
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data, iVec);
if (encryptCtr == null || encryptCtr.length == 0) {
LOGGER.severe(Common.addTag("[Masking] the return byte[] is null, please check!"));
return -1;
}
for (int i = 0; i < length; i++) {
int[] sub = new int[intV];
for (int j = 0; j < 4; j++) {
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
}
int subI = byte2int(sub, 4);
if (subI == -1) {
LOGGER.severe(Common.addTag("[Masking] the the returned <subI> is not valid: -1, please check!"));
return -1;
}
Float fNoise = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
noise.add(fNoise);
}
return 0;
}
private static int byte2int(int[] data, int number) {
if (data.length < 4) {
LOGGER.severe(Common.addTag("[Masking] the input argument <data> is not valid: length < 4, please check!"));
return -1;
}
switch (number) {
case 1:
return (int) data[0];
case 2:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
case 3:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
case 4:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
| (data[3] << 24 & 0xff000000);
default:
return 0;
}
}
}

View File

@ -1,82 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.List;
import java.util.logging.Logger;
public class Random {
/**
* random generate RNG algorithm name
*/
private static final Logger LOGGER = Logger.getLogger(Random.class.toString());
private static final String RNG_ALGORITHM = "SHA1PRNG";
private static final int RANDOM_LEN = 128 / 8;
public void getRandomBytes(byte[] secret) {
try {
SecureRandom secureRandom = SecureRandom.getInstance("SHA1PRNG");
secureRandom.nextBytes(secret);
} catch (NoSuchAlgorithmException e) {
e.printStackTrace();
}
}
public void randomAESCTR(List<Float> noise, int length, byte[] seed) throws Exception {
int intV = Integer.SIZE / 8;
int size = length * intV;
byte[] data = new byte[size];
for (int i = 0; i < size; i++) {
data[i] = 0;
}
byte[] ivec = new byte[RANDOM_LEN];
AESEncrypt aesEncrypt = new AESEncrypt(seed, ivec, "CTR");
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data);
for (int i = 0; i < length; i++) {
int[] sub = new int[intV];
for (int j = 0; j < 4; j++) {
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
}
int subI = byte2int(sub, 4);
Float f = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
noise.add(f);
}
}
public static int byte2int(int[] data, int n) {
switch (n) {
case 1:
return (int) data[0];
case 2:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
case 3:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
case 4:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
| (data[3] << 24 & 0xff000000);
default:
return 0;
}
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,23 +16,33 @@
package com.mindspore.flclient.cipher;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLCommunication;
import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import mindspore.schema.ClientShare;
import mindspore.schema.ResponseCode;
import mindspore.schema.ClientShare;
import mindspore.schema.ReconstructSecret;
import mindspore.schema.ResponseCode;
import mindspore.schema.SendReconstructSecret;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.LocalDateTime;
import java.util.Date;
import java.util.List;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
/**
* reconstruct secret request
*
* @since 2021-8-27
*/
public class ReconstructSecretReq {
private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString());
private FLCommunication flCommunication;
@ -41,36 +51,44 @@ public class ReconstructSecretReq {
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int retCode;
public String getNextRequestTime() {
return nextRequestTime;
}
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
public int getRetCode() {
return retCode;
}
/**
* reconstruct secret request
*/
public ReconstructSecretReq() {
flCommunication = FLCommunication.getInstance();
}
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList, List<String> u3ClientList, int iteration) {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
/**
* send secret shards to server
*
* @param decryptShareSecretsList secret shards list
* @param u3ClientList u3 client list
* @param iteration iter number
* @return request result
*/
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList,
List<String> u3ClientList, int iteration) {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
FlatBufferBuilder builder = new FlatBufferBuilder();
int desFlId = builder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString();
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = builder.createString(dateTime);
int shareSecretsSize = decryptShareSecretsList.size();
if (shareSecretsSize <= 0) {
LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please waite."));
LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please " +
"waite."));
return FLClientStatus.FAILED;
} else {
int[] decryptShareList = new int[shareSecretsSize];
for (int i = 0; i < shareSecretsSize; i++) {
DecryptShareSecrets decryptShareSecrets = decryptShareSecretsList.get(i);
if (decryptShareSecrets.getFlID() == null) {
LOGGER.severe(Common.addTag("[PairWiseMask] get remote flID failed!"));
return FLClientStatus.FAILED;
}
String srcFlId = decryptShareSecrets.getFlID();
byte[] share;
int index;
@ -86,31 +104,33 @@ public class ReconstructSecretReq {
int clientShare = ClientShare.createClientShare(builder, fbsSrcFlId, fbsShare, index);
decryptShareList[i] = clientShare;
}
int reconstructShareSecrets = mindspore.schema.SendReconstructSecret.createReconstructSecretSharesVector(builder, decryptShareList);
int reconstructSecretRoot = mindspore.schema.SendReconstructSecret.createSendReconstructSecret(builder, desFlId, reconstructShareSecrets, iteration, time);
int reconstructShareSecrets = SendReconstructSecret.createReconstructSecretSharesVector(builder,
decryptShareList);
int reconstructSecretRoot = SendReconstructSecret.createSendReconstructSecret(builder, desFlId,
reconstructShareSecrets, iteration, time);
builder.finish(reconstructSecretRoot);
byte[] msg = builder.sizedByteArray();
try {
byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
LOGGER.info(Common.addTag("[sendReconstructSecret] The cluster is in safemode, need wait some time and request again"));
if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[sendReconstructSecret] the server is not ready now, need wait some " +
"time and request again"));
Common.sleep(SLEEP_TIME);
nextRequestTime = "";
return FLClientStatus.RESTART;
}
ByteBuffer buffer = ByteBuffer.wrap(responseData);
mindspore.schema.ReconstructSecret reconstructSecretRsp = mindspore.schema.ReconstructSecret.getRootAsReconstructSecret(buffer);
FLClientStatus status = judgeSendReconstructSecrets(reconstructSecretRsp);
return status;
} catch (Exception e) {
ReconstructSecret reconstructSecretRsp = ReconstructSecret.getRootAsReconstructSecret(buffer);
return judgeSendReconstructSecrets(reconstructSecretRsp);
} catch (IOException ex) {
LOGGER.severe(Common.addTag("[PairWiseMask] un solved error code in reconstruct"));
e.printStackTrace();
ex.printStackTrace();
return FLClientStatus.FAILED;
}
}
}
public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) {
private FLClientStatus judgeSendReconstructSecrets(ReconstructSecret bufData) {
retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -122,7 +142,8 @@ public class ReconstructSecretReq {
LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success"));
return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request startFLJob again"));
LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART;
case (ResponseCode.RequestError):
@ -130,8 +151,36 @@ public class ReconstructSecretReq {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets"));
return FLClientStatus.FAILED;
default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is invalid: " + retCode));
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is " +
"invalid: " + retCode));
return FLClientStatus.FAILED;
}
}
/**
* get next request time
*
* @return next request time
*/
public String getNextRequestTime() {
return nextRequestTime;
}
/**
* set next request time
*
* @param nextRequestTime next request time
*/
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
/**
* get retCode
*
* @return retCode
*/
public int getRetCode() {
return retCode;
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,115 +22,158 @@ import java.math.BigInteger;
import java.util.Random;
import java.util.logging.Logger;
/**
* Define functions that for splitting secret and combining secret shards.
*
* @since 2021-06-30
*/
public class ShareSecrets {
private static final Logger LOGGER = Logger.getLogger(ShareSecrets.class.toString());
public final class SecretShare {
public SecretShare(final int num, final BigInteger share) {
this.num = num;
this.share = share;
}
private BigInteger prime;
private final int minNum;
private final int totalNum;
private final Random random;
public int getNum() {
return num;
/**
* Defines the constructor of the class ShareSecrets.
*
* @param minNum minimum number of fragments required to reconstruct a secret.
* @param totalNum total clients number.
*/
public ShareSecrets(final int minNum, final int totalNum) {
if (minNum <= 0) {
LOGGER.severe(Common.addTag("the argument <k> is not valid: <= 0, it should be > 0"));
throw new IllegalArgumentException();
}
public BigInteger getShare() {
return share;
if (totalNum <= 0) {
LOGGER.severe(Common.addTag("the argument <n> is not valid: <= 0, it should be > 0"));
throw new IllegalArgumentException();
}
@Override
public String toString() {
return "SecretShare [num=" + num + ", share=" + share + "]";
if (minNum > totalNum) {
LOGGER.severe(Common.addTag("the argument <k, n> is not valid: k > n, it should k <= n"));
throw new IllegalArgumentException();
}
private final int num;
private final BigInteger share;
this.minNum = minNum;
this.totalNum = totalNum;
random = Common.getSecureRandom();
}
public ShareSecrets(final int k, final int n) {
this.k = k;
this.n = n;
random = new Random();
}
public SecretShare[] split(final byte[] bytes, byte[] primeByte) {
/**
* Splits a secret into a specified number of secret fragments.
*
* @param bytes the secret need to be split.
* @param primeByte teh big prime number used to combine secret fragments.
* @return the secret fragments.
*/
public SecretShares[] split(final byte[] bytes, byte[] primeByte) {
if (bytes == null || bytes.length == 0) {
LOGGER.severe(Common.addTag("the input argument <bytes> is null"));
return new SecretShares[0];
}
if (primeByte == null || primeByte.length == 0) {
LOGGER.severe(Common.addTag("the input argument <primeByte> is null"));
return new SecretShares[0];
}
BigInteger secret = BaseUtil.byteArray2BigInteger(bytes);
final int modLength = secret.bitLength() + 1;
prime = BaseUtil.byteArray2BigInteger(primeByte);
final BigInteger[] coeff = new BigInteger[k - 1];
final BigInteger[] coefficient = new BigInteger[minNum - 1];
LOGGER.info(Common.addTag("Prime Number: " + prime));
for (int i = 0; i < k - 1; i++) {
coeff[i] = randomZp(prime);
LOGGER.info(Common.addTag("a" + (i + 1) + ": " + coeff[i]));
for (int i = 0; i < minNum - 1; i++) {
coefficient[i] = randomZp(prime);
}
final SecretShare[] shares = new SecretShare[n];
for (int i = 1; i <= n; i++) {
BigInteger accum = secret;
final SecretShares[] shares = new SecretShares[totalNum];
for (int i = 1; i <= totalNum; i++) {
BigInteger accumulate = secret;
for (int j = 1; j < k; j++) {
final BigInteger t1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime);
final BigInteger t2 = coeff[j - 1].multiply(t1).mod(prime);
for (int j = 1; j < minNum; j++) {
final BigInteger b1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime);
final BigInteger b2 = coefficient[j - 1].multiply(b1).mod(prime);
accum = accum.add(t2).mod(prime);
accumulate = accumulate.add(b2).mod(prime);
}
shares[i - 1] = new SecretShare(i, accum);
LOGGER.info(Common.addTag("Share " + shares[i - 1]));
shares[i - 1] = new SecretShares(i, accumulate);
}
return shares;
}
public BigInteger getPrime() {
return prime;
}
public BigInteger combine(final SecretShare[] shares, final byte[] primeByte) {
/**
* Combine secret fragments.
*
* @param shares the secret fragments.
* @param primeByte teh big prime number used to combine secret fragments.
* @return the secrets combined by secret fragments.
*/
public BigInteger combine(final SecretShares[] shares, final byte[] primeByte) {
if (shares == null || shares.length == 0) {
LOGGER.severe(Common.addTag("the input argument <shares> is null"));
return BigInteger.ZERO;
}
if (primeByte == null || primeByte.length == 0) {
LOGGER.severe(Common.addTag("the input argument <primeByte> is null"));
return BigInteger.ZERO;
}
BigInteger primeNum = BaseUtil.byteArray2BigInteger(primeByte);
BigInteger accum = BigInteger.ZERO;
for (int j = 0; j < k; j++) {
BigInteger accumulate = BigInteger.ZERO;
for (int j = 0; j < minNum; j++) {
BigInteger num = BigInteger.ONE;
BigInteger den = BigInteger.ONE;
BigInteger tmp;
for (int m = 0; m < k; m++) {
for (int m = 0; m < minNum; m++) {
if (j != m) {
num = num.multiply(BigInteger.valueOf(shares[m].getNum())).mod(primeNum);
tmp = BigInteger.valueOf(shares[j].getNum()).multiply(BigInteger.valueOf(-1));
tmp = BigInteger.valueOf(shares[m].getNum()).add(tmp).mod(primeNum);
num = num.multiply(BigInteger.valueOf(shares[m].getNumber())).mod(primeNum);
tmp = BigInteger.valueOf(shares[j].getNumber()).multiply(BigInteger.valueOf(-1));
tmp = BigInteger.valueOf(shares[m].getNumber()).add(tmp).mod(primeNum);
den = den.multiply(tmp).mod(primeNum);
}
}
final BigInteger value = shares[j].getShare();
final BigInteger value = shares[j].getShares();
tmp = den.modInverse(primeNum);
tmp = tmp.multiply(num).mod(primeNum);
tmp = tmp.multiply(value).mod(primeNum);
accum = accum.add(tmp).mod(primeNum);
LOGGER.info(Common.addTag("value: " + value + ", tmp: " + tmp + ", accum: " + accum));
accumulate = accumulate.add(tmp).mod(primeNum);
}
LOGGER.info(Common.addTag("The secret is: " + accum));
return accum;
return accumulate;
}
private BigInteger randomZp(final BigInteger p) {
private BigInteger randomZp(final BigInteger num) {
while (true) {
final BigInteger r = new BigInteger(p.bitLength(), random);
if (r.compareTo(BigInteger.ZERO) > 0 && r.compareTo(p) < 0) {
return r;
final BigInteger rand = new BigInteger(num.bitLength(), random);
if (rand.compareTo(BigInteger.ZERO) > 0 && rand.compareTo(num) < 0) {
return rand;
}
}
}
private BigInteger prime;
private final int k;
private final int n;
private final Random random;
private final int SECRET_MAX_LEN = 32;
/**
* Define the structure for store secret fragments.
*/
public final class SecretShares {
private final int number;
private final BigInteger share;
public SecretShares(final int number, final BigInteger share) {
this.number = number;
this.share = share;
}
public int getNumber() {
return number;
}
public BigInteger getShares() {
return share;
}
@Override
public String toString() {
return "SecretShares [number=" + number + ", share=" + share + "]";
}
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,33 +16,114 @@
package com.mindspore.flclient.cipher.struct;
import com.mindspore.flclient.Common;
import java.util.logging.Logger;
/**
* public key class of secure aggregation
*
* @since 2021-8-27
*/
public class ClientPublicKey {
private static final Logger LOGGER = Logger.getLogger(ClientPublicKey.class.toString());
private String flID;
private NewArray<byte[]> cPK;
private NewArray<byte[]> sPk;
private NewArray<byte[]> pwIv;
private NewArray<byte[]> pwSalt;
/**
* get client's flID
*
* @return flID of this client
*/
public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of <flID> is null, please set it before use"));
throw new IllegalArgumentException();
}
return flID;
}
/**
* set client's flID
*
* @param flID hash value used for identify client
*/
public void setFlID(String flID) {
this.flID = flID;
}
/**
* get CPK of secure aggregation
*
* @return CPK of secure aggregation
*/
public NewArray<byte[]> getCPK() {
return cPK;
}
/**
* set CPK of secure aggregation
*
* @param cPK public key used for encryption
*/
public void setCPK(NewArray<byte[]> cPK) {
this.cPK = cPK;
}
/**
* get SPK of secure aggregation
*
* @return SPK of secure aggregation
*/
public NewArray<byte[]> getSPK() {
return sPk;
}
/**
* set SPK of secure aggregation
*
* @param sPk public key used for encryption
*/
public void setSPK(NewArray<byte[]> sPk) {
this.sPk = sPk;
}
/**
* get the IV value used for pairwise encrypt
*
* @return the IV value used for pairwise encrypt
*/
public NewArray<byte[]> getPwIv() {
return pwIv;
}
/**
* set the IV value used for pairwise encrypt
*
* @param pwIv IV value used for pairwise encrypt
*/
public void setPwIv(NewArray<byte[]> pwIv) {
this.pwIv = pwIv;
}
/**
* get salt value for secure aggregation
*
* @return salt value for secure aggregation
*/
public NewArray<byte[]> getPwSalt() {
return pwSalt;
}
/**
* set salt value for secure aggregation
*
* @param pwSalt salt value for secure aggregation
*/
public void setPwSalt(NewArray<byte[]> pwSalt) {
this.pwSalt = pwSalt;
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,49 +16,114 @@
package com.mindspore.flclient.cipher.struct;
import com.mindspore.flclient.Common;
import java.util.logging.Logger;
/**
* class used for set and get decryption shards
*
* @since 2021-8-27
*/
public class DecryptShareSecrets {
private static final Logger LOGGER = Logger.getLogger(DecryptShareSecrets.class.toString());
private String flID;
private NewArray<byte[]> sSkVu;
private NewArray<byte[]> bVu;
private int sIndex;
private int indexB;
/**
* get flID of client
*
* @return flID of this client
*/
public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
"use"));
throw new IllegalArgumentException();
}
return flID;
}
/**
* set flID for this client
*
* @param flID hash value used for identify client
*/
public void setFlID(String flID) {
this.flID = flID;
}
/**
* get secret key shards
*
* @return secret key shards
*/
public NewArray<byte[]> getSSkVu() {
return sSkVu;
}
/**
* set secret key shards
*
* @param sSkVu secret key shards
*/
public void setSSkVu(NewArray<byte[]> sSkVu) {
this.sSkVu = sSkVu;
}
/**
* get bu shards
*
* @return bu shards
*/
public NewArray<byte[]> getBVu() {
return bVu;
}
/**
* set bu shards
*
* @param bVu bu shards used for secure aggregation
*/
public void setBVu(NewArray<byte[]> bVu) {
this.bVu = bVu;
}
/**
* get index of secret shards
*
* @return index of secret shards
*/
public int getSIndex() {
return sIndex;
}
/**
* set index of secret shards
*
* @param sIndex index of secret shards
*/
public void setSIndex(int sIndex) {
this.sIndex = sIndex;
}
/**
* get index of bu shards
*
* @return index of bu shards
*/
public int getIndexB() {
return indexB;
}
/**
* set index of bu shards
*
* @param indexB index of bu shards
*/
public void setIndexB(int indexB) {
this.indexB = indexB;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,22 +16,57 @@
package com.mindspore.flclient.cipher.struct;
import com.mindspore.flclient.Common;
import java.util.logging.Logger;
/**
* class used for encrypt shares of secret
*
* @since 2021-8-27
*/
public class EncryptShare {
private static final Logger LOGGER = Logger.getLogger(DecryptShareSecrets.class.toString());
private String flID;
private NewArray<byte[]> share;
/**
* get client's flID
*
* @return flID of this client
*/
public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
"use"));
throw new IllegalArgumentException();
}
return flID;
}
/**
* set client's flID
*
* @param flID hash value used for identify client
*/
public void setFlID(String flID) {
this.flID = flID;
}
/**
* get secret share
*
* @return secret share
*/
public NewArray<byte[]> getShare() {
return share;
}
/**
* set secret share
*
* @param share secret share
*/
public void setShare(NewArray<byte[]> share) {
this.share = share;
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,24 +16,50 @@
package com.mindspore.flclient.cipher.struct;
/**
* class used define new array type
*
* @param <T> an array
*
* @since 2021-8-27
*/
public class NewArray<T> {
private int size;
private T array;
/**
* get array size
*
* @return array size
*/
public int getSize() {
return size;
}
/**
* set array size
*
* @param size array size
*/
public void setSize(int size) {
this.size = size;
}
/**
* get array
*
* @return an array
*/
public T getArray() {
return array;
}
/**
* set array
*
* @param array input
*/
public void setArray(T array) {
this.array = array;
}
}

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,31 +16,75 @@
package com.mindspore.flclient.cipher.struct;
import com.mindspore.flclient.Common;
import java.util.logging.Logger;
/**
* share secret class
*
* @since 2021-8-27
*/
public class ShareSecret {
private static final Logger LOGGER = Logger.getLogger(ShareSecret.class.toString());
private String flID;
private NewArray<byte[]> share;
private int index;
/**
* get client's flID
*
* @return flID of this client
*/
public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[ShareSecret] the parameter of <flID> is null, please set it before use"));
throw new IllegalArgumentException();
}
return flID;
}
/**
* set flID for this client
*
* @param flID hash value used for identify client
*/
public void setFlID(String flID) {
this.flID = flID;
}
/**
* get secret share
*
* @return secret share
*/
public NewArray<byte[]> getShare() {
return share;
}
/**
* set secret share
*
* @param share secret shares
*/
public void setShare(NewArray<byte[]> share) {
this.share = share;
}
/**
* get secret index
*
* @return secret index
*/
public int getIndex() {
return index;
}
/**
* set secret index
*
* @param index secret index
*/
public void setIndex(int index) {
this.index = index;
}

View File

@ -31,6 +31,8 @@ table ClientPublicKeys {
fl_id:string;
c_pk:[ubyte];
s_pk: [ubyte];
pw_iv: [ubyte];
pw_salt: [ubyte];
}
table ClientShare {
@ -45,6 +47,9 @@ table RequestExchangeKeys{
s_pk:[ubyte];
iteration:int;
timestamp:string;
ind_iv:[ubyte];
pw_iv:[ubyte];
pw_salt:[ubyte];
}
table ResponseExchangeKeys{