Fix security problems and code-check problems for federated's secure aggregation

fix security check problems for flclient
This commit is contained in:
jin-xiulang 2021-09-10 10:30:49 +08:00
parent 0abff9ad65
commit 7d9dd343f3
68 changed files with 4321 additions and 2225 deletions

View File

@ -39,7 +39,7 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc") list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/key_agreement.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/key_agreement.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/random.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/masking.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/secret_sharing.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/secret_sharing.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_init.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_init.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_keys.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_keys.cc")

View File

@ -22,26 +22,28 @@
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_initial_client_cnt, size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
size_t cipher_exchange_secrets_cnt, size_t cipher_share_secrets_cnt,
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt, size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
size_t cipher_reconstruct_secrets_up_cnt) { size_t cipher_reconstruct_secrets_up_cnt) {
MS_LOG(INFO) << "CipherInit::Init START"; MS_LOG(INFO) << "CipherInit::Init START";
int return_num = 0; if (memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN) != 0) {
return_num = memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN); MS_LOG(ERROR) << "CipherInit::memory copy failed.";
if (return_num != 0) {
return false; return false;
} }
publicparam_.g = param.g; publicparam_.g = param.g;
publicparam_.t = param.t; publicparam_.t = param.t;
secrets_minnums_ = param.t; secrets_minnums_ = param.t;
client_num_need_ = cipher_initial_client_cnt;
featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float); featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float);
share_clients_num_need_ = cipher_share_secrets_cnt;
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1; exchange_key_threshold = cipher_exchange_keys_cnt;
get_model_num_need_ = cipher_get_clientlist_cnt; get_key_threshold = cipher_get_keys_cnt;
share_secrets_threshold = cipher_share_secrets_cnt;
get_secrets_threshold = cipher_get_secrets_cnt;
client_list_threshold = cipher_get_clientlist_cnt;
reconstruct_secrets_threshold = cipher_reconstruct_secrets_up_cnt;
time_out_mutex_ = time_out_mutex; time_out_mutex_ = time_out_mutex;
publicparam_.dp_eps = param.dp_eps; publicparam_.dp_eps = param.dp_eps;
publicparam_.dp_delta = param.dp_delta; publicparam_.dp_delta = param.dp_delta;
@ -62,10 +64,12 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
MS_LOG(ERROR) << "Cipher Param Update is invalid."; MS_LOG(ERROR) << "Cipher Param Update is invalid.";
return false; return false;
} }
MS_LOG(INFO) << " CipherInit client_num_need_ : " << client_num_need_; MS_LOG(INFO) << " CipherInit exchange_key_threshold : " << exchange_key_threshold;
MS_LOG(INFO) << " CipherInit share_clients_num_need_ : " << share_clients_num_need_; MS_LOG(INFO) << " CipherInit get_key_threshold : " << get_key_threshold;
MS_LOG(INFO) << " CipherInit reconstruct_clients_num_need_ : " << reconstruct_clients_num_need_; MS_LOG(INFO) << " CipherInit share_secrets_threshold : " << share_secrets_threshold;
MS_LOG(INFO) << " CipherInit get_model_num_need_ : " << get_model_num_need_; MS_LOG(INFO) << " CipherInit get_secrets_threshold : " << get_secrets_threshold;
MS_LOG(INFO) << " CipherInit client_list_threshold : " << client_list_threshold;
MS_LOG(INFO) << " CipherInit reconstruct_secrets_threshold : " << reconstruct_secrets_threshold;
MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_; MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_;
if (!Check_Parames()) { if (!Check_Parames()) {
MS_LOG(ERROR) << "Cipher parameters are illegal."; MS_LOG(ERROR) << "Cipher parameters are illegal.";
@ -82,11 +86,10 @@ bool CipherInit::Check_Parames() {
return false; return false;
} }
if (share_clients_num_need_ < reconstruct_clients_num_need_) { if (share_secrets_threshold < reconstruct_secrets_threshold) {
MS_LOG(ERROR) MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger "
<< "reconstruct_clients_num_need (which is reconstruct_secrets_threshold + 1) should not be larger " "than share_secrets_threshold, but got they are:"
"than share_clients_num_need (which is start_fl_job_threshold*share_secrets_ratio), but got they are:" << reconstruct_secrets_threshold << ", " << share_secrets_threshold;
<< reconstruct_clients_num_need_ << ", " << share_clients_num_need_;
return false; return false;
} }

View File

@ -29,17 +29,6 @@
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
template <typename T1>
bool CreateArray(std::vector<T1> *newData, const flatbuffers::Vector<T1> &fbs_arr) {
size_t size = newData->size();
size_t size_fbs_arr = fbs_arr.size();
if (size != size_fbs_arr) return false;
for (size_t i = 0; i < size; ++i) {
newData->at(i) = fbs_arr.Get(i);
}
return true;
}
// Initialization of secure aggregation. // Initialization of secure aggregation.
class CipherInit { class CipherInit {
public: public:
@ -49,9 +38,10 @@ class CipherInit {
} }
// Initialize the parameters of the secure aggregation. // Initialize the parameters of the secure aggregation.
bool Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_initial_client_cnt, bool Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
size_t cipher_exchange_secrets_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_clientlist_cnt, size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
size_t cipher_reconstruct_secrets_down_cnt, size_t cipher_reconstruct_secrets_up_cnt); size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
size_t cipher_reconstruct_secrets_up_cnt);
// Check whether the parameters are valid. // Check whether the parameters are valid.
bool Check_Parames(); bool Check_Parames();
@ -59,10 +49,12 @@ class CipherInit {
// Get public params. which is given to start fl job thread. // Get public params. which is given to start fl job thread.
CipherPublicPara *GetPublicParams() { return &publicparam_; } CipherPublicPara *GetPublicParams() { return &publicparam_; }
size_t share_clients_num_need_; // the minimum number of clients to share secrets. size_t share_secrets_threshold; // the minimum number of clients to share secret fragments.
size_t reconstruct_clients_num_need_; // the minimum number of clients to reconstruct secret mask. size_t get_secrets_threshold; // the minimum number of clients to get secret fragments.
size_t client_num_need_; // the minimum number of clients to update model. size_t reconstruct_secrets_threshold; // the minimum number of clients to reconstruct secret mask.
size_t get_model_num_need_; // the minimum number of clients to get model. size_t exchange_key_threshold; // the minimum number of clients to send public keys.
size_t get_key_threshold; // the minimum number of clients to get public keys.
size_t client_list_threshold; // the minimum number of clients to get update model client list.
size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask. size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask.
size_t featuremap_; // the size of data to deal. size_t featuremap_; // the size of data to deal.

View File

@ -21,203 +21,170 @@ namespace mindspore {
namespace armour { namespace armour {
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time, bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req, const schema::GetExchangeKeys *get_exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &get_exchange_keys_resp_builder) { const std::shared_ptr<fl::server::FBBuilder> &fbb) {
MS_LOG(INFO) << "CipherMgr::GetKeys START"; MS_LOG(INFO) << "CipherMgr::GetKeys START";
if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) { if (get_exchange_keys_req == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr."; MS_LOG(ERROR) << "Request is nullptr";
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false); BuildGetKeysRsp(fbb, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false);
return false; return false;
} }
// get clientlist from memory server. // get clientlist from memory server.
std::vector<std::string> clients; std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &clients); size_t cur_exchange_clients_num = client_public_keys.size();
size_t cur_clients_num = clients.size();
std::string fl_id = get_exchange_keys_req->fl_id()->str(); std::string fl_id = get_exchange_keys_req->fl_id()->str();
if (find(clients.begin(), clients.end(), fl_id) == clients.end()) { if (cur_exchange_clients_num < cipher_init_->exchange_key_threshold) {
MS_LOG(INFO) << "The fl_id is not in clients."; MS_LOG(INFO) << "The server is not ready yet: cur_exchangekey_clients_num < exchange_key_threshold";
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false); MS_LOG(INFO) << "cur_exchangekey_clients_num : " << cur_exchange_clients_num
<< ", exchange_key_threshold : " << cipher_init_->exchange_key_threshold;
BuildGetKeysRsp(fbb, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
return false; return false;
} }
if (cur_clients_num < cipher_init_->client_num_need_) {
MS_LOG(INFO) << "The server is not ready yet: cur_clients_num < client_num_need"; if (client_public_keys.find(fl_id) == client_public_keys.end()) {
MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << ", client_num_need : " << cipher_init_->client_num_need_; MS_LOG(INFO) << "Get keys: the fl_id: " << fl_id << "is not in exchange keys clients.";
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false); BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
return false;
}
bool ret = cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetKeysClientList, fl_id);
if (!ret) {
MS_LOG(ERROR) << "update get keys clients failed";
BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, cur_iterator, next_req_time, false);
return false; return false;
} }
MS_LOG(INFO) << "GetKeys client list: "; MS_LOG(INFO) << "GetKeys client list: ";
for (size_t i = 0; i < clients.size(); i++) { BuildGetKeysRsp(fbb, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
MS_LOG(INFO) << "fl_id: " << clients[i]; return true;
} }
bool flag =
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
return flag;
} // namespace armour
bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time, bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req, const schema::RequestExchangeKeys *exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder) { const std::shared_ptr<fl::server::FBBuilder> &fbb) {
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START"; MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
// step 0: judge if the input param is legal. // step 0: judge if the input param is legal.
if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) { if (exchange_keys_req == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr."; std::string reason = "Request is nullptr";
std::string reason = "Request is nullptr or Response builder is nullptr."; MS_LOG(ERROR) << reason;
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time, BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason, next_req_time, cur_iterator);
cur_iterator); return false;
}
std::string fl_id = exchange_keys_req->fl_id()->str();
mindspore::fl::PBMetadata device_metas =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxDeviceMetas);
mindspore::fl::FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
MS_LOG(INFO) << "exchange key for fl id " << fl_id;
if (fl_id_to_meta.fl_id_to_meta().count(fl_id) == 0) {
std::string reason = "devices_meta for " + fl_id + " is not set. Please retry later.";
BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator);
MS_LOG(ERROR) << reason;
return false; return false;
} }
// step 1: get clientlist and client keys from memory server. // step 1: get clientlist and client keys from memory server.
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys; std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
std::vector<std::string> client_list; std::vector<std::string> client_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list); cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys); cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
// step2: process new item data. and update new item data to memory server. // step2: process new item data. and update new item data to memory server.
size_t cur_clients_num = client_list.size(); size_t cur_clients_num = client_list.size();
size_t cur_clients_has_keys_num = record_public_keys.size(); size_t cur_clients_has_keys_num = client_public_keys.size();
if (cur_clients_num != cur_clients_has_keys_num) { if (cur_clients_num != cur_clients_has_keys_num) {
std::string reason = "client num and keys num are not equal."; std::string reason = "client num and keys num are not equal.";
MS_LOG(ERROR) << reason; MS_LOG(WARNING) << reason;
MS_LOG(ERROR) << "cur_clients_num is " << cur_clients_num << ". cur_clients_has_keys_num is " MS_LOG(WARNING) << "cur_clients_num is " << cur_clients_num << ". cur_clients_has_keys_num is "
<< cur_clients_has_keys_num; << cur_clients_has_keys_num;
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, reason, next_req_time,
cur_iterator);
return false;
} }
MS_LOG(WARNING) << "exchange_key_threshold " << cipher_init_->exchange_key_threshold << ". cur_clients_num "
<< cur_clients_num << ". cur_clients_keys_num " << cur_clients_has_keys_num;
MS_LOG(INFO) << "client_num_need_ " << cipher_init_->client_num_need_ << ". cur_clients_num " << cur_clients_num; if (client_public_keys.find(fl_id) != client_public_keys.end()) { // the client already exists, return false.
std::string fl_id = exchange_keys_req->fl_id()->str(); MS_LOG(ERROR) << "The server has received the request, please do not request again.";
if (cur_clients_num >= cipher_init_->client_num_need_) { // the client num is enough, return false. BuildExchangeKeysRsp(fbb, schema::ResponseCode_SUCCEED,
MS_LOG(ERROR) << "The server has received enough requests and refuse this request.";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime,
"The server has received enough requests and refuse this request.", next_req_time,
cur_iterator);
return false;
}
if (record_public_keys.find(fl_id) != record_public_keys.end()) { // the client already exists, return false.
MS_LOG(INFO) << "The server has received the request, please do not request again.";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
"The server has received the request, please do not request again.", next_req_time, "The server has received the request, please do not request again.", next_req_time,
cur_iterator); cur_iterator);
return false; return false;
} }
// Gets the members of the deserialized data exchange_keys_req
auto fbs_cpk = exchange_keys_req->c_pk();
size_t cpk_len = fbs_cpk->size();
auto fbs_spk = exchange_keys_req->s_pk();
size_t spk_len = fbs_spk->size();
// transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
std::vector<std::vector<unsigned char>> cur_public_key;
std::vector<unsigned char> cpk(cpk_len);
std::vector<unsigned char> spk(spk_len);
bool ret_create_code_cpk = CreateArray<unsigned char>(&cpk, *fbs_cpk);
bool ret_create_code_spk = CreateArray<unsigned char>(&spk, *fbs_spk);
if (!(ret_create_code_cpk && ret_create_code_spk)) {
MS_LOG(ERROR) << "create cur_public_key failed";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, "update key or client failed",
next_req_time, cur_iterator);
return false;
}
cur_public_key.push_back(cpk);
cur_public_key.push_back(spk);
bool retcode_key = bool retcode_key =
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, fl_id, cur_public_key); cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
bool retcode_client = bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id); cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
if (retcode_key && retcode_client) { if (retcode_key && retcode_client) {
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success"; MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, BuildExchangeKeysRsp(fbb, schema::ResponseCode_SUCCEED, "Success, but the server is not ready yet.", next_req_time,
"Success, but the server is not ready yet.", next_req_time, cur_iterator); cur_iterator);
return true; return true;
} else { } else {
MS_LOG(ERROR) << "update key or client failed"; MS_LOG(ERROR) << "update key or client failed";
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, "update key or client failed", BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "update key or client failed", next_req_time,
next_req_time, cur_iterator); cur_iterator);
return false; return false;
} }
} }
void CipherKeys::BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder, void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const schema::ResponseCode retcode, const std::string &reason, const std::string &reason, const std::string &next_req_time,
const std::string &next_req_time, const int iteration) { const int iteration) {
auto rsp_reason = exchange_keys_resp_builder->CreateString(reason); auto rsp_reason = fbb->CreateString(reason);
auto rsp_next_req_time = exchange_keys_resp_builder->CreateString(next_req_time); auto rsp_next_req_time = fbb->CreateString(next_req_time);
schema::ResponseExchangeKeysBuilder rsp_builder(*(exchange_keys_resp_builder.get()));
schema::ResponseExchangeKeysBuilder rsp_builder(*(fbb.get()));
rsp_builder.add_retcode(retcode); rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason); rsp_builder.add_reason(rsp_reason);
rsp_builder.add_next_req_time(rsp_next_req_time); rsp_builder.add_next_req_time(rsp_next_req_time);
rsp_builder.add_iteration(iteration); rsp_builder.add_iteration(iteration);
auto rsp_exchange_keys = rsp_builder.Finish(); auto rsp_exchange_keys = rsp_builder.Finish();
exchange_keys_resp_builder->Finish(rsp_exchange_keys); fbb->Finish(rsp_exchange_keys);
return; return;
} }
bool CipherKeys::BuildGetKeys(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode, void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good) { const int iteration, const std::string &next_req_time, bool is_good) {
bool flag = true; if (!is_good) {
if (is_good) {
// convert client keys to standard keys list.
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
MS_LOG(INFO) << "Get Keys: ";
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
if (record_public_keys.size() < cipher_init_->client_num_need_) {
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
<< "clients num: " << cipher_init_->client_num_need_;
flag = false;
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
} else {
for (auto iter = record_public_keys.begin(); iter != record_public_keys.end(); ++iter) {
// read (fl_id, c_pk, s_pk) from the map: record_public_keys_
std::string fl_id = iter->first;
MS_LOG(INFO) << "fl_id : " << fl_id;
// To serialize the members to a new TableClientPublicKeys
auto fbs_fl_id = fbb->CreateString(fl_id);
auto fbs_c_pk = fbb->CreateVector(iter->second[0].data(), iter->second[0].size());
auto fbs_s_pk = fbb->CreateVector(iter->second[1].data(), iter->second[1].size());
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk);
public_keys_list.push_back(cur_public_key);
}
auto remote_publickeys = fbb->CreateVector(public_keys_list);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_remote_publickeys(remote_publickeys);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
}
} else {
auto fbs_next_req_time = fbb->CreateString(next_req_time); auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get())); schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode); rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration); rsp_buider.add_iteration(iteration);
rsp_buider.add_next_req_time(fbs_next_req_time); rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish(); auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys); fbb->Finish(rsp_get_keys);
return;
} }
return flag; const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientsKeys);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
std::string fl_id = iter->first;
fl::KeysPb keys_pb = iter->second;
auto fbs_fl_id = fbb->CreateString(fl_id);
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size());
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
auto fbs_pw_iv = fbb->CreateVector(pw_iv.data(), pw_iv.size());
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
auto fbs_pw_salt = fbb->CreateVector(pw_salt.data(), pw_salt.size());
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
public_keys_list.push_back(cur_public_key);
}
auto remote_publickeys = fbb->CreateVector(public_keys_list);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
rsp_buider.add_retcode(retcode);
rsp_buider.add_iteration(iteration);
rsp_buider.add_remote_publickeys(remote_publickeys);
rsp_buider.add_next_req_time(fbs_next_req_time);
auto rsp_get_keys = rsp_buider.Finish();
fbb->Finish(rsp_get_keys);
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
return;
} }
void CipherKeys::ClearKeys() { void CipherKeys::ClearKeys() {

View File

@ -44,21 +44,19 @@ class CipherKeys {
// handle the client's request of get keys. // handle the client's request of get keys.
bool GetKeys(const int cur_iterator, const std::string &next_req_time, bool GetKeys(const int cur_iterator, const std::string &next_req_time,
const schema::GetExchangeKeys *get_exchange_keys_req, const schema::GetExchangeKeys *get_exchange_keys_req, const std::shared_ptr<fl::server::FBBuilder> &fbb);
const std::shared_ptr<fl::server::FBBuilder> &get_exchange_keys_resp_builder);
// handle the client's request of exchange keys. // handle the client's request of exchange keys.
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time, bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
const schema::RequestExchangeKeys *exchange_keys_req, const schema::RequestExchangeKeys *exchange_keys_req,
const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder); const std::shared_ptr<fl::server::FBBuilder> &fbb);
// build response code of get keys. // build response code of get keys.
bool BuildGetKeys(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode, void BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const int iteration, const std::string &next_req_time, bool is_good); const int iteration, const std::string &next_req_time, bool is_good);
// build response code of exchange keys. // build response code of exchange keys.
void BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder, void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
const schema::ResponseCode retcode, const std::string &reason, const std::string &reason, const std::string &next_req_time, const int iteration);
const std::string &next_req_time, const int iteration);
// clear the shared memory. // clear the shared memory.
void ClearKeys(); void ClearKeys();

View File

@ -51,7 +51,7 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect
} }
void CipherMetaStorage::GetClientKeysFromServer( void CipherMetaStorage::GetClientKeysFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) { const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) {
const fl::PBMetadata &clients_keys_pb_out = const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys(); const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
@ -60,12 +60,33 @@ void CipherMetaStorage::GetClientKeysFromServer(
// const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i); // const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
std::string fl_id = iter->first; std::string fl_id = iter->first;
fl::KeysPb keys_pb = iter->second; fl::KeysPb keys_pb = iter->second;
std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end()); std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end()); std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
std::vector<std::vector<unsigned char>> cur_keys; std::vector<std::vector<uint8_t>> cur_keys;
cur_keys.push_back(cpk); cur_keys.push_back(cpk);
cur_keys.push_back(spk); cur_keys.push_back(spk);
clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<unsigned char>>>(fl_id, cur_keys)); clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
}
}
void CipherMetaStorage::GetClientIVsFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) {
const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
std::string fl_id = iter->first;
fl::KeysPb keys_pb = iter->second;
std::vector<uint8_t> ind_iv(keys_pb.ind_iv().begin(), keys_pb.ind_iv().end());
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
std::vector<std::vector<uint8_t>> cur_ivs;
cur_ivs.push_back(ind_iv);
cur_ivs.push_back(pw_iv);
cur_ivs.push_back(pw_salt);
clients_ivs_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs));
} }
} }
@ -73,9 +94,18 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
const fl::PBMetadata &clients_noises_pb_out = const fl::PBMetadata &clients_noises_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises(); const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
int count = 0;
int count_thld = 100;
while (clients_noises_pb.has_one_client_noises() == false) { while (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "GetClientNoisesFromServer NULL."; int register_time = 500;
std::this_thread::sleep_for(std::chrono::milliseconds(50)); std::this_thread::sleep_for(std::chrono::milliseconds(register_time));
count++;
if (count >= count_thld) break;
}
MS_LOG(INFO) << "GetClientNoisesFromServer Count: " << count;
if (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(ERROR) << "GetClientNoisesFromServer NULL.";
return false;
} }
cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(), cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(),
clients_noises_pb.one_client_noises().noise().end()); clients_noises_pb.one_client_noises().noise().end());
@ -98,14 +128,14 @@ bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char
} }
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) { bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
bool retcode = true;
fl::FLId fl_id_pb; fl::FLId fl_id_pb;
fl_id_pb.set_fl_id(fl_id); fl_id_pb.set_fl_id(fl_id);
fl::PBMetadata client_pb; fl::PBMetadata client_pb;
client_pb.mutable_fl_id()->MergeFrom(fl_id_pb); client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb); bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
return retcode; return retcode;
} }
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) { void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
MS_LOG(INFO) << "register prime: " << prime; MS_LOG(INFO) << "register prime: " << prime;
fl::Prime prime_id_pb; fl::Prime prime_id_pb;
@ -117,9 +147,9 @@ void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &
} }
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id, bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
const std::vector<std::vector<unsigned char>> &cur_public_key) { const std::vector<std::vector<uint8_t>> &cur_public_key) {
bool retcode = true; size_t correct_size = 2;
if (cur_public_key.size() < 2) { if (cur_public_key.size() < correct_size) {
MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size(); MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
return false; return false;
} }
@ -132,7 +162,73 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys); pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
fl::PBMetadata client_and_keys_pb; fl::PBMetadata client_and_keys_pb;
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb); client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb); bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
return retcode;
}
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
const schema::RequestExchangeKeys *exchange_keys_req) {
std::string fl_id = exchange_keys_req->fl_id()->str();
auto fbs_cpk = exchange_keys_req->c_pk();
auto fbs_spk = exchange_keys_req->s_pk();
if (fbs_cpk == nullptr || fbs_spk == nullptr) {
MS_LOG(ERROR) << "public key from exchange_keys_req is null";
return false;
}
size_t spk_len = fbs_spk->size();
size_t cpk_len = fbs_cpk->size();
// transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
std::vector<std::vector<uint8_t>> cur_public_key;
std::vector<uint8_t> cpk(cpk_len);
std::vector<uint8_t> spk(spk_len);
bool ret_create_code_cpk = CreateArray<uint8_t>(&cpk, *fbs_cpk);
bool ret_create_code_spk = CreateArray<uint8_t>(&spk, *fbs_spk);
if (!(ret_create_code_cpk && ret_create_code_spk)) {
MS_LOG(ERROR) << "create array for public keys failed";
return false;
}
cur_public_key.push_back(cpk);
cur_public_key.push_back(spk);
auto fbs_ind_iv = exchange_keys_req->ind_iv();
std::vector<char> ind_iv;
if (fbs_ind_iv == nullptr) {
MS_LOG(WARNING) << "ind_iv in exchange_keys_req is nullptr";
} else {
ind_iv.assign(fbs_ind_iv->begin(), fbs_ind_iv->end());
}
auto fbs_pw_iv = exchange_keys_req->pw_iv();
std::vector<char> pw_iv;
if (fbs_pw_iv == nullptr) {
MS_LOG(WARNING) << "pw_iv in exchange_keys_req is nullptr";
} else {
pw_iv.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
}
auto fbs_pw_salt = exchange_keys_req->pw_salt();
std::vector<char> pw_salt;
if (fbs_pw_salt == nullptr) {
MS_LOG(WARNING) << "pw_salt in exchange_keys_req is nullptr";
} else {
pw_salt.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
}
// update new item to memory server.
fl::KeysPb keys;
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
keys.set_ind_iv(ind_iv.data(), ind_iv.size());
keys.set_pw_iv(pw_iv.data(), pw_iv.size());
keys.set_pw_salt(pw_salt.data(), pw_salt.size());
fl::PairClientKeys pair_client_keys_pb;
pair_client_keys_pb.set_fl_id(fl_id);
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
fl::PBMetadata client_and_keys_pb;
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
return retcode; return retcode;
} }
@ -142,13 +238,13 @@ bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const s
*noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()}; *noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
fl::PBMetadata client_noises_pb; fl::PBMetadata client_noises_pb;
client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb); client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
return fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb); bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
return ret;
} }
bool CipherMetaStorage::UpdateClientShareToServer( bool CipherMetaStorage::UpdateClientShareToServer(
const char *list_name, const std::string &fl_id, const char *list_name, const std::string &fl_id,
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) { const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
bool retcode = true;
int size_shares = shares->size(); int size_shares = shares->size();
fl::SharesPb shares_pb; fl::SharesPb shares_pb;
for (int index = 0; index < size_shares; ++index) { for (int index = 0; index < size_shares; ++index) {
@ -166,14 +262,17 @@ bool CipherMetaStorage::UpdateClientShareToServer(
pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb); pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
fl::PBMetadata client_and_shares_pb; fl::PBMetadata client_and_shares_pb;
client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb); client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb); bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
return retcode; return retcode;
} }
void CipherMetaStorage::RegisterClass() { void CipherMetaStorage::RegisterClass() {
fl::PBMetadata exchange_kyes_client_list; fl::PBMetadata exchange_keys_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList, fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
exchange_kyes_client_list); exchange_keys_client_list);
fl::PBMetadata get_keys_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetKeysClientList,
get_keys_client_list);
fl::PBMetadata clients_keys; fl::PBMetadata clients_keys;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys); fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
fl::PBMetadata reconstruct_client_list; fl::PBMetadata reconstruct_client_list;
@ -185,9 +284,15 @@ void CipherMetaStorage::RegisterClass() {
fl::PBMetadata share_secretes_client_list; fl::PBMetadata share_secretes_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList, fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
share_secretes_client_list); share_secretes_client_list);
fl::PBMetadata get_secretes_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetSecretsClientList,
get_secretes_client_list);
fl::PBMetadata clients_encrypt_shares; fl::PBMetadata clients_encrypt_shares;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares, fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
clients_encrypt_shares); clients_encrypt_shares);
fl::PBMetadata get_update_clients_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetUpdateModelClientList,
get_update_clients_list);
} }
} // namespace armour } // namespace armour
} // namespace mindspore } // namespace mindspore

View File

@ -31,23 +31,39 @@
#include "fl/server/distributed_metadata_store.h" #include "fl/server/distributed_metadata_store.h"
#include "fl/server/common.h" #include "fl/server/common.h"
#define IND_IV_INDEX 0
#define PW_IV_INDEX 1
#define PW_SALT_INDEX 2
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
template <typename T1>
bool CreateArray(std::vector<T1> *newData, const flatbuffers::Vector<T1> &fbs_arr) {
if (newData == nullptr) return false;
size_t size = newData->size();
size_t size_fbs_arr = fbs_arr.size();
if (size != size_fbs_arr) return false;
for (size_t i = 0; i < size; ++i) {
newData->at(i) = fbs_arr.Get(i);
}
return true;
}
constexpr int SHARE_MAX_SIZE = 256; constexpr int SHARE_MAX_SIZE = 256;
constexpr int SECRET_MAX_LEN_DOUBLE = 66; constexpr int SECRET_MAX_LEN_DOUBLE = 66;
struct clientshare_str { struct clientshare_str {
std::string fl_id; std::string fl_id;
std::vector<unsigned char> share; std::vector<uint8_t> share;
int index; int index;
}; };
struct CipherPublicPara { struct CipherPublicPara {
int t; int t;
int g; int g;
unsigned char prime[PRIME_MAX_LEN]; uint8_t prime[PRIME_MAX_LEN];
unsigned char p[SECRET_MAX_LEN]; uint8_t p[SECRET_MAX_LEN];
float dp_eps; float dp_eps;
float dp_delta; float dp_delta;
float dp_norm_clip; float dp_norm_clip;
@ -62,7 +78,7 @@ class CipherMetaStorage {
// Register Prime. // Register Prime.
void RegisterPrime(const char *list_name, const std::string &prime); void RegisterPrime(const char *list_name, const std::string &prime);
// Get tprime from shared server. // Get tprime from shared server.
bool GetPrimeFromServer(const char *prime_name, unsigned char *prime); bool GetPrimeFromServer(const char *prime_name, uint8_t *prime);
// Get client shares from shared server. // Get client shares from shared server.
void GetClientSharesFromServer(const char *list_name, void GetClientSharesFromServer(const char *list_name,
std::map<std::string, std::vector<clientshare_str>> *clients_shares_list); std::map<std::string, std::vector<clientshare_str>> *clients_shares_list);
@ -70,14 +86,18 @@ class CipherMetaStorage {
void GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list); void GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list);
// Get client keys from shared server. // Get client keys from shared server.
void GetClientKeysFromServer(const char *list_name, void GetClientKeysFromServer(const char *list_name,
std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list); std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list);
void GetClientIVsFromServer(const char *list_name,
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list);
// Get client noises from shared server. // Get client noises from shared server.
bool GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise); bool GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise);
// Update client fl_id to shared server. // Update client fl_id to shared server.
bool UpdateClientToServer(const char *list_name, const std::string &fl_id); bool UpdateClientToServer(const char *list_name, const std::string &fl_id);
// Update client key to shared server. // Update client key to shared server.
bool UpdateClientKeyToServer(const char *list_name, const std::string &fl_id, bool UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
const std::vector<std::vector<unsigned char>> &cur_public_key); const std::vector<std::vector<uint8_t>> &cur_public_key);
// Update client key with signature to shared server.
bool UpdateClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req);
// Update client noise to shared server. // Update client noise to shared server.
bool UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise); bool UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise);
// Update client share to shared server. // Update client share to shared server.

View File

@ -16,33 +16,35 @@
#include "fl/armour/cipher/cipher_reconstruct.h" #include "fl/armour/cipher/cipher_reconstruct.h"
#include "fl/server/common.h" #include "fl/server/common.h"
#include "fl/armour/secure_protocol/random.h" #include "fl/armour/secure_protocol/masking.h"
#include "fl/armour/secure_protocol/key_agreement.h" #include "fl/armour/secure_protocol/key_agreement.h"
#include "fl/armour/cipher/cipher_meta_storage.h" #include "fl/armour/cipher/cipher_meta_storage.h"
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
bool CipherReconStruct::CombineMask( bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp,
std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_keys, std::map<std::string, std::vector<float>> *client_noise,
const std::vector<std::string> &clients_share_list, const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys, const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list, const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
const std::vector<string> &client_list) { const std::vector<string> &client_list,
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs) {
bool retcode = true; bool retcode = true;
#ifdef _WIN32 #ifdef _WIN32
MS_LOG(ERROR) << "Unsupported feature in Windows platform."; MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
retcode = false; retcode = false;
#else #else
if (shares_tmp == nullptr || client_noise == nullptr) {
MS_LOG(ERROR) << "shares_tmp or client_noise is nullptr.";
return false;
}
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) { for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
// define flag_share: judge we need b or s // define flag_share: judge we need b or s
bool flag_share = true; bool flag_share = true;
const std::string fl_id = iter->first; const std::string fl_id = iter->first;
std::vector<std::string>::const_iterator ptr = client_list.begin(); if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) {
for (; ptr < client_list.end(); ++ptr) { // the client is online
if (*ptr == fl_id) { flag_share = false;
flag_share = false;
break;
}
} }
MS_LOG(INFO) << "fl_id_src : " << fl_id; MS_LOG(INFO) << "fl_id_src : " << fl_id;
BIGNUM *prime = BN_new(); BIGNUM *prime = BN_new();
@ -61,7 +63,6 @@ bool CipherReconStruct::CombineMask(
MS_LOG(ERROR) << "shares_tmp copy failed"; MS_LOG(ERROR) << "shares_tmp copy failed";
retcode = false; retcode = false;
} }
MS_LOG(INFO) << "fl_id_des : " << (iter->second)[i].fl_id;
std::string print_share_data(reinterpret_cast<const char *>(shares_tmp->at(i)->data), shares_tmp->at(i)->len); std::string print_share_data(reinterpret_cast<const char *>(shares_tmp->at(i)->data), shares_tmp->at(i)->len);
} }
MS_LOG(INFO) << "end assign secrets shares to public shares "; MS_LOG(INFO) << "end assign secrets shares to public shares ";
@ -74,23 +75,42 @@ bool CipherReconStruct::CombineMask(
MS_LOG(INFO) << "combine secrets shares Success."; MS_LOG(INFO) << "combine secrets shares Success.";
if (flag_share) { if (flag_share) {
MS_LOG(INFO) << "start get complete s_uv."; // reconstruct pairwise noise
MS_LOG(INFO) << "start reconstruct pairwise noise.";
std::vector<float> noise(cipher_init_->featuremap_, 0.0); std::vector<float> noise(cipher_init_->featuremap_, 0.0);
if (GetSuvNoise(clients_share_list, record_public_keys, fl_id, &noise, secret, length) == false) if (GetSuvNoise(clients_share_list, record_public_keys, client_ivs, fl_id, &noise, secret, length) == false)
retcode = false; retcode = false;
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise)); client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
MS_LOG(INFO) << " fl_id : " << fl_id; MS_LOG(INFO) << " fl_id : " << fl_id;
MS_LOG(INFO) << "end get complete s_uv."; MS_LOG(INFO) << "end get complete s_uv.";
} else { } else {
// reconstruct individual noise
MS_LOG(INFO) << "start reconstruct individual noise.";
std::vector<float> noise; std::vector<float> noise;
if (Random::RandomAESCTR(&noise, cipher_init_->featuremap_, (const unsigned char *)secret, SECRET_MAX_LEN) < 0) auto it = client_ivs.find(fl_id);
if (it == client_ivs.end()) {
MS_LOG(ERROR) << "cannot get ivs for client: " << fl_id;
return false;
}
if (it->second.size() != IV_NUM) {
MS_LOG(ERROR) << "get " << it->second.size() << " ivs, the iv num required is: " << IV_NUM;
return false;
}
std::vector<uint8_t> ind_iv = it->second[0];
if (Masking::GetMasking(&noise, cipher_init_->featuremap_, (const uint8_t *)secret, SECRET_MAX_LEN,
ind_iv.data(), ind_iv.size()) < 0)
retcode = false; retcode = false;
for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) { for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) {
noise[index_noise] *= -1; noise[index_noise] *= -1;
} }
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise)); client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
MS_LOG(INFO) << " fl_id : " << fl_id;
} }
} else {
MS_LOG(ERROR) << "reconstruct secret failed: the number of secret shares for fl_id: " << fl_id
<< " is not enough";
MS_LOG(ERROR) << "get " << iter->second.size()
<< "shares, however the secrets_minnums_ required is: " << cipher_init_->secrets_minnums_;
return false;
} }
} }
#endif #endif
@ -98,173 +118,188 @@ bool CipherReconStruct::CombineMask(
} }
bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &client_list) { bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &client_list) {
// get reconstruct_secret_list_ori from memory server // get reconstruct_secrets from memory server
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START"; MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
bool retcode = true; bool retcode = true;
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori; std::map<std::string, std::vector<clientshare_str>> reconstruct_secrets;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares, cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&reconstruct_secret_list_ori); &reconstruct_secrets);
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys; std::map<std::string, std::vector<std::vector<uint8_t>>> record_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys); cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
std::vector<std::string> clients_reconstruct_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList, std::map<std::string, std::vector<std::vector<uint8_t>>> client_ivs;
&clients_reconstruct_list); cipher_init_->cipher_meta_storage_.GetClientIVsFromServer(fl::server::kCtxClientsKeys, &client_ivs);
std::vector<std::string> clients_share_list; std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList, cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list); &clients_share_list);
if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() || if (record_public_keys.size() < cipher_init_->exchange_key_threshold ||
record_public_keys.size() < cipher_init_->client_num_need_ || clients_share_list.size() < cipher_init_->share_secrets_threshold ||
clients_share_list.size() < cipher_init_->share_clients_num_need_) { record_public_keys.size() != client_ivs.size()) {
MS_LOG(ERROR) << "send share client size: " << clients_share_list.size()
<< ", send public-key client size: " << record_public_keys.size()
<< ", send ivs client size: " << client_ivs.size();
MS_LOG(ERROR) << "get data from server memory failed"; MS_LOG(ERROR) << "get data from server memory failed";
return false; return false;
} }
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list; std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list;
ConvertSharesToShares(reconstruct_secret_list_ori, &reconstruct_secret_list); if (!ConvertSharesToShares(reconstruct_secrets, &reconstruct_secret_list)) {
MS_LOG(ERROR) << "ConvertSharesToShares failed.";
return false;
}
MS_LOG(ERROR) << "recombined shares";
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
MS_LOG(ERROR) << "fl_id: " << iter->first;
MS_LOG(ERROR) << "share size: " << iter->second.size();
}
std::vector<Share *> shares_tmp; std::vector<Share *> shares_tmp;
if (MallocShares(&shares_tmp, cipher_init_->secrets_minnums_) == false) { if (!MallocShares(&shares_tmp, cipher_init_->secrets_minnums_)) {
MS_LOG(ERROR) << "Reconstruct malloc shares_tmp invalid."; MS_LOG(ERROR) << "Reconstruct malloc shares_tmp invalid.";
return false; return false;
} }
MS_LOG(INFO) << "Reconstruct client list: ";
std::vector<std::string>::const_iterator ptr_tmp = client_list.begin();
for (; ptr_tmp < client_list.end(); ++ptr_tmp) {
MS_LOG(INFO) << *ptr_tmp;
}
MS_LOG(INFO) << "Reconstruct secrets shares: "; MS_LOG(INFO) << "Reconstruct secrets shares: ";
std::map<std::string, std::vector<float>> client_keys; std::map<std::string, std::vector<float>> client_noise;
retcode = CombineMask(&shares_tmp, &client_noise, clients_share_list, record_public_keys, reconstruct_secret_list,
retcode = CombineMask(&shares_tmp, &client_keys, clients_share_list, record_public_keys, reconstruct_secret_list, client_list, client_ivs);
client_list);
DeleteShares(&shares_tmp); DeleteShares(&shares_tmp);
if (retcode) { if (retcode) {
std::vector<float> noise; std::vector<float> noise;
if (GetNoiseMasksSum(&noise, client_keys) == false) { if (!GetNoiseMasksSum(&noise, client_noise)) {
MS_LOG(ERROR) << " GetNoiseMasksSum failed"; MS_LOG(ERROR) << " GetNoiseMasksSum failed";
return false; return false;
} }
client_keys.clear(); client_noise.clear();
MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server"; MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";
if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise) == false) if (!cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise)) {
MS_LOG(ERROR) << " ReconstructSecretsGenNoise failed. because UpdateClientNoiseToServer failed";
return false; return false;
}
MS_LOG(INFO) << " ReconstructSecretsGenNoise Success"; MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
} else { } else {
MS_LOG(INFO) << " ReconstructSecretsGenNoise failed. because gen noise inside failed"; MS_LOG(ERROR) << " ReconstructSecretsGenNoise failed. because gen noise inside failed";
} }
return retcode; return retcode;
} }
// reconstruct secrets // reconstruct secrets
bool CipherReconStruct::ReconstructSecrets( bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const int cur_iterator, const std::string &next_req_time, const schema::SendReconstructSecret *reconstruct_secret_req, const schema::SendReconstructSecret *reconstruct_secret_req,
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder, const std::shared_ptr<fl::server::FBBuilder> &fbb,
const std::vector<std::string> &client_list) { const std::vector<std::string> &client_list) {
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START"; MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
clock_t start_time = clock(); clock_t start_time = clock();
if (reconstruct_secret_req == nullptr || reconstruct_secret_resp_builder == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr. ";
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_RequestError,
"Request is nullptr or Response builder is nullptr.", cur_iterator, next_req_time);
return false;
}
if (client_list.size() < cipher_init_->reconstruct_clients_num_need_) {
MS_LOG(ERROR) << "illegal parameters. update model client_list size: " << client_list.size();
BuildReconstructSecretsRsp(
reconstruct_secret_resp_builder, schema::ResponseCode_RequestError,
"illegal parameters: update model client_list size must larger than reconstruct_clients_num_need", cur_iterator,
next_req_time);
return false;
}
std::vector<std::string> clients_reconstruct_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
&clients_reconstruct_list);
std::map<std::string, std::vector<clientshare_str>> clients_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&clients_shares_all);
size_t count_client_num = clients_shares_all.size(); if (reconstruct_secret_req == nullptr) {
if (count_client_num != clients_reconstruct_list.size()) { std::string reason = "Request is nullptr";
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime, MS_LOG(ERROR) << reason;
"shares client size and client size are not equal.", cur_iterator, next_req_time); BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, cur_iterator, next_req_time);
MS_LOG(ERROR) << "shares client size and client size are not equal.";
return false; return false;
} }
int iterator = reconstruct_secret_req->iteration(); int iterator = reconstruct_secret_req->iteration();
std::string fl_id = reconstruct_secret_req->fl_id()->str(); std::string fl_id = reconstruct_secret_req->fl_id()->str();
if (iterator != cur_iterator) { if (iterator != cur_iterator) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime, BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"The iteration round of the client does not match the current iteration.", cur_iterator, "The iteration round of the client does not match the current iteration.", cur_iterator,
next_req_time); next_req_time);
MS_LOG(ERROR) << "Client " << fl_id << " The iteration round of the client does not match the current iteration."; MS_LOG(ERROR) << "Client " << fl_id << " The iteration round of the client does not match the current iteration.";
return false; return false;
} }
if (find(client_list.begin(), client_list.end(), fl_id) == client_list.end()) { // client not in client list. if (client_list.size() < cipher_init_->reconstruct_secrets_threshold) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime, MS_LOG(ERROR) << "illegal parameters. update model client_list size: " << client_list.size();
"The client is not in update model client list.", cur_iterator, next_req_time); BuildReconstructSecretsRsp(
MS_LOG(ERROR) << "The client " << fl_id << " is not in update model client list."; fbb, schema::ResponseCode_RequestError,
"illegal parameters: update model client_list size must larger than reconstruct_clients_num_need", cur_iterator,
next_req_time);
return false; return false;
} }
if (find(clients_reconstruct_list.begin(), clients_reconstruct_list.end(), fl_id) != clients_reconstruct_list.end()) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED, std::vector<std::string> get_clients_list;
"Client has sended messages.", cur_iterator, next_req_time); cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetUpdateModelClientList,
&get_clients_list);
// client not in get client list.
if (find(get_clients_list.begin(), get_clients_list.end(), fl_id) == get_clients_list.end()) {
std::string reason;
MS_LOG(INFO) << "The client " << fl_id << " is not in get update model client list.";
// client in update model client list.
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) {
reason = "The client " + fl_id + " is not in get clients list, but in update model client list.";
MS_LOG(INFO) << reason;
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, reason, cur_iterator, next_req_time);
return false;
}
reason = "The client " + fl_id + " is not in get clients list, and not in update model client list.";
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, "The client is not in update model client list.",
cur_iterator, next_req_time);
return false;
}
std::map<std::string, std::vector<clientshare_str>> reconstruct_shares;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
&reconstruct_shares);
size_t count_client_num = reconstruct_shares.size();
if (reconstruct_shares.find(fl_id) != reconstruct_shares.end()) {
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Client has sended messages.", cur_iterator,
next_req_time);
MS_LOG(INFO) << "Error, client " << fl_id << " has sended messages."; MS_LOG(INFO) << "Error, client " << fl_id << " has sended messages.";
return false; return false;
} }
auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares(); auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
bool retcode_client = bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id); cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id);
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer( bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares); fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
if (!(retcode_client && retcode_share)) { if (!(retcode_client && retcode_share)) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime, BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "reconstruct update shares or client failed.",
"reconstruct update shares or client failed.", cur_iterator, next_req_time); cur_iterator, next_req_time);
MS_LOG(ERROR) << "reconstruct update shares or client failed."; MS_LOG(ERROR) << "reconstruct update shares or client failed.";
return false; return false;
} }
count_client_num = count_client_num + 1; count_client_num = count_client_num + 1;
if (count_client_num < cipher_init_->reconstruct_clients_num_need_) { if (count_client_num < cipher_init_->reconstruct_secrets_threshold) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED, BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED,
"Success,but the server is not ready to reconstruct secret yet.", cur_iterator, "Success, but the server is not ready to reconstruct secret yet.", cur_iterator,
next_req_time); next_req_time);
MS_LOG(INFO) << "ReconstructSecrets" << fl_id << " Success, but count " << count_client_num << "is not enough."; MS_LOG(INFO) << "ReconstructSecrets " << fl_id << " Success, but count " << count_client_num << "is not enough.";
return true; return true;
} else {
bool retcode_result = true;
const fl::PBMetadata &clients_noises_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
if (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "Success,the secret will be reconstructed.";
retcode_result = ReconstructSecretsGenNoise(client_list);
if (retcode_result) {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
"Success,the secret is reconstructing.", cur_iterator, next_req_time);
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success, reconstruct ok.";
} else {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
"the secret restructs failed.", cur_iterator, next_req_time);
MS_LOG(ERROR) << "the secret restructs failed.";
}
} else {
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
"Clients' number is full.", cur_iterator, next_req_time);
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success : no need reconstruct.";
}
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "Reconstruct get + gennoise data time is : " << duration;
return retcode_result;
} }
const fl::PBMetadata &clients_noises_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
if (clients_noises_pb.has_one_client_noises() == false) {
MS_LOG(INFO) << "Success, the secret will be reconstructed.";
if (ReconstructSecretsGenNoise(client_list)) {
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Success,the secret is reconstructing.",
cur_iterator, next_req_time);
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success, reconstruct ok.";
} else {
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "the secret restructs failed.", cur_iterator,
next_req_time);
MS_LOG(ERROR) << "CipherReconStruct::ReconstructSecrets" << fl_id << " failed.";
}
} else {
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Clients' number is full.", cur_iterator,
next_req_time);
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success : no need reconstruct.";
}
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "Reconstruct get + gennoise data time is : " << duration;
return true;
} }
bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result, bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,
const std::map<std::string, std::vector<float>> &client_keys) { const std::map<std::string, std::vector<float>> &client_noise) {
std::vector<float> sum(cipher_init_->featuremap_, 0.0); std::vector<float> sum(cipher_init_->featuremap_, 0.0);
for (auto iter = client_keys.begin(); iter != client_keys.end(); iter++) { for (auto iter = client_noise.begin(); iter != client_noise.end(); iter++) {
if (iter->second.size() != cipher_init_->featuremap_) { if (iter->second.size() != cipher_init_->featuremap_) {
return false; return false;
} }
@ -301,35 +336,52 @@ void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::ser
return; return;
} }
bool CipherReconStruct::GetSuvNoise( bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_share_list,
const std::vector<std::string> &clients_share_list, const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys, const string &fl_id, const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs,
std::vector<float> *noise, uint8_t *secret, int length) { const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length) {
for (auto p_key = clients_share_list.begin(); p_key != clients_share_list.end(); ++p_key) { for (auto p_key = clients_share_list.begin(); p_key != clients_share_list.end(); ++p_key) {
if (*p_key != fl_id) { if (*p_key != fl_id) {
PrivateKey *privKey1 = KeyAgreement::FromPrivateBytes((unsigned char *)secret, length); PrivateKey *privKey = KeyAgreement::FromPrivateBytes(secret, length);
if (privKey1 == NULL) { if (privKey == NULL) {
MS_LOG(ERROR) << "create privKey1 failed\n"; MS_LOG(ERROR) << "create privKey failed\n";
return false; return false;
} }
std::vector<unsigned char> public_key = record_public_keys.at(*p_key)[1]; std::vector<uint8_t> public_key = record_public_keys.at(*p_key)[1];
PublicKey *pubKey1 = KeyAgreement::FromPublicBytes(public_key.data(), public_key.size()); std::string iv_fl_id;
if (pubKey1 == NULL) { if (fl_id < *p_key) {
MS_LOG(ERROR) << "create pubKey1 failed\n"; iv_fl_id = fl_id;
} else {
iv_fl_id = *p_key;
}
auto iter = client_ivs.find(iv_fl_id);
if (iter == client_ivs.end()) {
MS_LOG(ERROR) << "cannot get ivs for client: " << iv_fl_id;
return false; return false;
} }
MS_LOG(INFO) << "fl_id : " << fl_id << "other id : " << *p_key; if (iter->second.size() != IV_NUM) {
unsigned char secret1[SECRET_MAX_LEN] = {0}; MS_LOG(ERROR) << "get " << iter->second.size() << " ivs, the iv num required is: " << IV_NUM;
unsigned char salt[SECRET_MAX_LEN] = {0}; return false;
if (KeyAgreement::ComputeSharedKey(privKey1, pubKey1, SECRET_MAX_LEN, salt, SECRET_MAX_LEN, secret1) < 0) { }
std::vector<uint8_t> pw_iv = iter->second[PW_IV_INDEX];
std::vector<uint8_t> pw_salt = iter->second[PW_SALT_INDEX];
PublicKey *pubKey = KeyAgreement::FromPublicBytes(public_key.data(), public_key.size());
if (pubKey == NULL) {
MS_LOG(ERROR) << "create pubKey failed\n";
return false;
}
MS_LOG(INFO) << "private_key fl_id : " << fl_id << " public_key fl_id : " << *p_key;
uint8_t secret1[SECRET_MAX_LEN] = {0};
if (KeyAgreement::ComputeSharedKey(privKey, pubKey, SECRET_MAX_LEN, pw_salt.data(), pw_salt.size(), secret1) <
0) {
MS_LOG(ERROR) << "ComputeSharedKey failed\n"; MS_LOG(ERROR) << "ComputeSharedKey failed\n";
return false; return false;
} }
std::vector<float> noise_tmp; std::vector<float> noise_tmp;
if (Random::RandomAESCTR(&noise_tmp, cipher_init_->featuremap_, (const unsigned char *)secret1, SECRET_MAX_LEN) < if (Masking::GetMasking(&noise_tmp, cipher_init_->featuremap_, (const uint8_t *)secret1, SECRET_MAX_LEN,
0) { pw_iv.data(), pw_iv.size()) < 0) {
MS_LOG(ERROR) << "RandomAESCTR failed\n"; MS_LOG(ERROR) << "Get Masking failed\n";
return false; return false;
} }
bool symbol_noise = GetSymbol(fl_id, *p_key); bool symbol_noise = GetSymbol(fl_id, *p_key);
@ -345,9 +397,6 @@ bool CipherReconStruct::GetSuvNoise(
noise->at(index) += noise_tmp[index]; noise->at(index) += noise_tmp[index];
} }
} }
for (int i = 0; i < 5; i++) {
MS_LOG(INFO) << "index " << i << " : " << noise_tmp[i];
}
} }
} }
return true; return true;
@ -361,37 +410,41 @@ bool CipherReconStruct::GetSymbol(const std::string &str1, const std::string &st
} }
} }
void CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src, // recombined shares by their source fl_id (ownners)
bool CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
std::map<std::string, std::vector<clientshare_str>> *des) { std::map<std::string, std::vector<clientshare_str>> *des) {
for (auto iter_ori = src.begin(); iter_ori != src.end(); ++iter_ori) { if (des == nullptr) return false;
std::string fl_des = iter_ori->first; for (auto iter = src.begin(); iter != src.end(); ++iter) {
auto &cur_clientshare_str = iter_ori->second; std::string des_id = iter->first;
auto &cur_clientshare_str = iter->second;
for (size_t index_clientshare = 0; index_clientshare < cur_clientshare_str.size(); ++index_clientshare) { for (size_t index_clientshare = 0; index_clientshare < cur_clientshare_str.size(); ++index_clientshare) {
std::string fl_src = cur_clientshare_str[index_clientshare].fl_id; std::string src_id = cur_clientshare_str[index_clientshare].fl_id;
clientshare_str value; clientshare_str value;
value.fl_id = fl_des; value.fl_id = des_id;
value.share = cur_clientshare_str[index_clientshare].share; value.share = cur_clientshare_str[index_clientshare].share;
value.index = cur_clientshare_str[index_clientshare].index; value.index = cur_clientshare_str[index_clientshare].index;
if (des->find(fl_src) == des->end()) { // fl_id_des is not in reconstruct_secret_list_ if (des->find(src_id) == des->end()) { // src_id is not in recombined shares list
std::vector<clientshare_str> value_list; std::vector<clientshare_str> value_list;
value_list.push_back(value); value_list.push_back(value);
des->insert(std::pair<std::string, std::vector<clientshare_str>>(fl_src, value_list)); des->insert(std::pair<std::string, std::vector<clientshare_str>>(src_id, value_list));
} else { // fl_id_des is in reconstruct_secret_list_ } else {
des->at(fl_src).push_back(value); des->at(src_id).push_back(value);
} }
} }
} }
return true;
} }
bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int shares_size) { bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int shares_size) {
if (shares_tmp == nullptr) return false;
for (int i = 0; i < shares_size; ++i) { for (int i = 0; i < shares_size; ++i) {
Share *share_i = new Share; Share *share_i = new Share();
if (share_i == nullptr) { if (share_i == nullptr) {
MS_LOG(ERROR) << "shares_tmp " << i << " memory to cipher is invalid."; MS_LOG(ERROR) << "shares_tmp " << i << " memory to cipher is invalid.";
DeleteShares(shares_tmp); DeleteShares(shares_tmp);
return false; return false;
} }
share_i->data = new unsigned char[SHARE_MAX_SIZE]; share_i->data = new uint8_t[SHARE_MAX_SIZE];
if (share_i->data == nullptr) { if (share_i->data == nullptr) {
MS_LOG(ERROR) << "shares_tmp's data " << i << " memory to cipher is invalid."; MS_LOG(ERROR) << "shares_tmp's data " << i << " memory to cipher is invalid.";
DeleteShares(shares_tmp); DeleteShares(shares_tmp);
@ -405,6 +458,7 @@ bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int share
} }
void CipherReconStruct::DeleteShares(std::vector<Share *> *shares_tmp) { void CipherReconStruct::DeleteShares(std::vector<Share *> *shares_tmp) {
if (shares_tmp == nullptr) return;
if (shares_tmp->size() != 0) { if (shares_tmp->size() != 0) {
for (size_t i = 0; i < shares_tmp->size(); ++i) { for (size_t i = 0; i < shares_tmp->size(); ++i) {
if (shares_tmp->at(i) != nullptr && shares_tmp->at(i)->data != nullptr) { if (shares_tmp->at(i) != nullptr && shares_tmp->at(i)->data != nullptr) {

View File

@ -28,6 +28,8 @@
#include "fl/armour/cipher/cipher_init.h" #include "fl/armour/cipher/cipher_init.h"
#include "fl/armour/cipher/cipher_meta_storage.h" #include "fl/armour/cipher/cipher_meta_storage.h"
#define IV_NUM 3
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
// The process of reconstruct secret mask in the secure aggregation // The process of reconstruct secret mask in the secure aggregation
@ -44,7 +46,7 @@ class CipherReconStruct {
// reconstruct secret mask // reconstruct secret mask
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time, bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
const schema::SendReconstructSecret *reconstruct_secret_req, const schema::SendReconstructSecret *reconstruct_secret_req,
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder, const std::shared_ptr<fl::server::FBBuilder> &fbb,
const std::vector<std::string> &client_list); const std::vector<std::string> &client_list);
// build response code of reconstruct secret. // build response code of reconstruct secret.
@ -60,26 +62,28 @@ class CipherReconStruct {
bool GetSymbol(const std::string &str1, const std::string &str2); bool GetSymbol(const std::string &str1, const std::string &str2);
// get suv noise by computing shares result. // get suv noise by computing shares result.
bool GetSuvNoise(const std::vector<std::string> &clients_share_list, bool GetSuvNoise(const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys, const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length); const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs, const string &fl_id,
std::vector<float> *noise, uint8_t *secret, int length);
// malloc shares. // malloc shares.
bool MallocShares(std::vector<Share *> *shares_tmp, int shares_size); bool MallocShares(std::vector<Share *> *shares_tmp, int shares_size);
// delete shares. // delete shares.
void DeleteShares(std::vector<Share *> *shares_tmp); void DeleteShares(std::vector<Share *> *shares_tmp);
// convert shares from receiving clients to sending clients. // convert shares from receiving clients to sending clients.
void ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src, bool ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
std::map<std::string, std::vector<clientshare_str>> *des); std::map<std::string, std::vector<clientshare_str>> *des);
// generate noise from shares. // generate noise from shares.
bool ReconstructSecretsGenNoise(const std::vector<string> &client_list); bool ReconstructSecretsGenNoise(const std::vector<string> &client_list);
// get noise masks sum. // get noise masks sum.
bool GetNoiseMasksSum(std::vector<float> *result, const std::map<std::string, std::vector<float>> &client_keys); bool GetNoiseMasksSum(std::vector<float> *result, const std::map<std::string, std::vector<float>> &client_noise);
// combine noise mask. // combine noise mask.
bool CombineMask(std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_keys, bool CombineMask(std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_noise,
const std::vector<std::string> &clients_share_list, const std::vector<std::string> &clients_share_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys, const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list, const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
const std::vector<string> &client_list); const std::vector<string> &client_list,
const std::map<std::string, std::vector<std::vector<unsigned char>>> &client_ivs);
}; };
} // namespace armour } // namespace armour
} // namespace mindspore } // namespace mindspore

View File

@ -25,8 +25,8 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
const string next_req_time) { const string next_req_time) {
MS_LOG(INFO) << "CipherShares::ShareSecrets START"; MS_LOG(INFO) << "CipherShares::ShareSecrets START";
if (share_secrets_req == nullptr) { if (share_secrets_req == nullptr) {
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr."; std::string reason = "Request is nullptr";
std::string reason = "Request is nullptr or Response builder is nullptr."; MS_LOG(ERROR) << reason;
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time, BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time,
cur_iterator); cur_iterator);
return false; return false;
@ -34,38 +34,34 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
// step 1: get client list and share secrets from memory server. // step 1: get client list and share secrets from memory server.
clock_t start_time = clock(); clock_t start_time = clock();
int iteration = share_secrets_req->iteration();
std::vector<std::string> get_keys_clients;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetKeysClientList, &get_keys_clients);
std::vector<std::string> clients_share_list; std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList, cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list); &clients_share_list);
std::vector<std::string> clients_exchange_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList,
&clients_exchange_list);
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all; std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares, cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
&encrypted_shares_all); &encrypted_shares_all);
MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size() MS_LOG(INFO) << "Client of get keys size : " << get_keys_clients.size()
<< "client of shares size : " << clients_share_list.size() << "shares size" << "client of update shares size : " << clients_share_list.size()
<< encrypted_shares_all.size(); << "updated shares size: " << encrypted_shares_all.size();
if (encrypted_shares_all.size() != clients_share_list.size()) {
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime,
"client of shares and shares size are not equal", next_req_time, cur_iterator);
MS_LOG(ERROR) << "client of shares and shares size are not equal. client of shares size : "
<< clients_share_list.size() << "shares size" << encrypted_shares_all.size();
}
// step 2: update new item to memory server. serialise: update pb struct to memory server. // step 2: update new item to memory server. serialise: update pb struct to memory server.
int iteration = share_secrets_req->iteration();
std::string fl_id_src = share_secrets_req->fl_id()->str(); std::string fl_id_src = share_secrets_req->fl_id()->str();
if (find(clients_exchange_list.begin(), clients_exchange_list.end(), fl_id_src) == if (find(get_keys_clients.begin(), get_keys_clients.end(), fl_id_src) == get_keys_clients.end()) {
clients_exchange_list.end()) { // the client not in clients_exchange_list, return false. // the client not in get keys clients
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError, BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError,
("client share secret is not in clients_exchange list. && client is illegal"), next_req_time, ("client share secret is not in getkeys list. && client is illegal"), next_req_time,
iteration); iteration);
return false; return false;
} }
if (find(clients_share_list.begin(), clients_share_list.end(), fl_id_src) !=
clients_share_list.end()) { // the client is already exists, return false. if (encrypted_shares_all.find(fl_id_src) != encrypted_shares_all.end()) { // the client is already exists
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED,
("client sharesecret already exists."), next_req_time, iteration); ("client sharesecret already exists."), next_req_time, iteration);
return false; return false;
@ -74,24 +70,22 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
// update new item to memory server. // update new item to memory server.
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares = const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
(share_secrets_req->encrypted_shares()); (share_secrets_req->encrypted_shares());
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
bool retcode_client = bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src); cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src);
bool retcode = retcode_share && retcode_client; bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
if (retcode) { fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration); if (!(retcode_share && retcode_client)) {
MS_LOG(INFO) << "CipherShares::ShareSecrets Success";
} else {
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime, BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime,
"update client of shares and shares failed", next_req_time, iteration); "update client of shares and shares failed", next_req_time, iteration);
MS_LOG(ERROR) << "CipherShares::ShareSecrets update client of shares and shares failed "; MS_LOG(ERROR) << "CipherShares::ShareSecrets update client of shares and shares failed ";
return false;
} }
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
MS_LOG(INFO) << "CipherShares::ShareSecrets Success";
clock_t end_time = clock(); clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "ShareSecrets get + deal + update data time is : " << duration; MS_LOG(INFO) << "ShareSecrets get + deal + update data time is : " << duration;
return retcode; return true;
} }
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req, bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
@ -101,36 +95,38 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
clock_t start_time = clock(); clock_t start_time = clock();
// step 0: check whether the parameters are legal. // step 0: check whether the parameters are legal.
if (get_secrets_req == nullptr) { if (get_secrets_req == nullptr) {
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, 0); BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, nullptr);
MS_LOG(ERROR) << "GetSecrets: get_secrets_req is nullptr."; MS_LOG(ERROR) << "GetSecrets: get_secrets_req is nullptr.";
return false; return false;
} }
// step 1: get client list and client shares list from memory server. // step 1: get client list and client shares list from memory server.
std::vector<std::string> clients_share_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
&clients_share_list);
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all; std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares, cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
&encrypted_shares_all); &encrypted_shares_all);
int iteration = get_secrets_req->iteration(); int iteration = get_secrets_req->iteration();
size_t share_clients_num = clients_share_list.size(); size_t encrypted_shares_num = encrypted_shares_all.size();
size_t cients_has_shares = encrypted_shares_all.size(); if (cipher_init_->share_secrets_threshold > encrypted_shares_num) { // the client num is not enough, return false.
if (share_clients_num != cients_has_shares) { BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_OutOfTime, iteration, next_req_time, 0); MS_LOG(INFO) << "GetSecrets: the encrypted shares num is not enough: share_secrets_threshold: "
MS_LOG(ERROR) << "cients_has_shares: " << cients_has_shares << "share_clients_num: " << share_clients_num; << cipher_init_->share_secrets_threshold << "encrypted_shares_num: " << encrypted_shares_num;
}
if (cipher_init_->share_clients_num_need_ > share_clients_num) { // the client num is not enough, return false.
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, 0);
MS_LOG(INFO) << "GetSecrets: the client num is not enough: share_clients_num_need_: "
<< cipher_init_->share_clients_num_need_ << "share_clients_num: " << share_clients_num;
return false; return false;
} }
std::string fl_id = get_secrets_req->fl_id()->str(); std::string fl_id = get_secrets_req->fl_id()->str();
if (find(clients_share_list.begin(), clients_share_list.end(), fl_id) == // the client is not in share secrets client list.
clients_share_list.end()) { // the client is not in client list, return false. if (encrypted_shares_all.find(fl_id) == encrypted_shares_all.end()) {
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, 0); BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, nullptr);
MS_LOG(ERROR) << "GetSecrets: client is not in client list."; MS_LOG(ERROR) << "GetSecrets: client is not in share secrets client list.";
return false;
}
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetSecretsClientList, fl_id);
if (!retcode_client) {
MS_LOG(ERROR) << "update get secrets clients failed";
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
return false;
} }
// get the result client shares. // get the result client shares.
@ -171,7 +167,6 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SUCCEED, iteration, next_req_time, BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SUCCEED, iteration, next_req_time,
&encrypted_shares); &encrypted_shares);
MS_LOG(INFO) << "CipherShares::GetSecrets Success"; MS_LOG(INFO) << "CipherShares::GetSecrets Success";
clock_t end_time = clock(); clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
@ -185,7 +180,7 @@ void CipherShares::BuildGetSecretsRsp(
int rsp_retcode = retcode; int rsp_retcode = retcode;
int rsp_iteration = iteration; int rsp_iteration = iteration;
auto rsp_next_req_time = get_secrets_resp_builder->CreateString(next_req_time); auto rsp_next_req_time = get_secrets_resp_builder->CreateString(next_req_time);
if (encrypted_shares == 0) { if (encrypted_shares == nullptr) {
auto get_secrets_rsp = auto get_secrets_rsp =
schema::CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration, 0, rsp_next_req_time); schema::CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration, 0, rsp_next_req_time);
get_secrets_resp_builder->Finish(get_secrets_rsp); get_secrets_resp_builder->Finish(get_secrets_rsp);
@ -195,7 +190,6 @@ void CipherShares::BuildGetSecretsRsp(
encrypted_shares_rsp, rsp_next_req_time); encrypted_shares_rsp, rsp_next_req_time);
get_secrets_resp_builder->Finish(get_secrets_rsp); get_secrets_resp_builder->Finish(get_secrets_rsp);
} }
return; return;
} }

View File

@ -26,8 +26,8 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
clock_t start_time = clock(); clock_t start_time = clock();
std::vector<float> noise; std::vector<float> noise;
(void)cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise); bool ret = cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
if (noise.size() != cipher_init_->featuremap_) { if (!ret || noise.size() != cipher_init_->featuremap_) {
MS_LOG(ERROR) << " CipherMgr UnMask ERROR"; MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
return false; return false;
} }

View File

@ -18,12 +18,7 @@
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
AESEncrypt::AESEncrypt(const uint8_t *key, int key_len, const uint8_t *ivec, int ivec_len, const AES_MODE mode) {
#define KEY_STEP_MAX 32
#define KEY_STEP_MIN 16
#define PAD_SIZE 5
AESEncrypt::AESEncrypt(const unsigned char *key, int key_len, unsigned char *ivec, int ivec_len, const AES_MODE mode) {
privKey = key; privKey = key;
privKeyLen = key_len; privKeyLen = key_len;
iVec = ivec; iVec = ivec;
@ -47,12 +42,20 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
#else #else
int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) { int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) {
int ret; int ret;
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) { if (privKey == NULL || iVec == NULL) {
MS_LOG(ERROR) << "key length must be 16 or 32!"; MS_LOG(ERROR) << "private key or init vector is invalid.";
return -1; return -1;
} }
if (iVecLen != INIT_VEC_SIZE) { if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
MS_LOG(ERROR) << "initial vector size must be 16!"; MS_LOG(ERROR) << "key length is invalid.";
return -1;
}
if (iVecLen != AES_IV_SIZE) {
MS_LOG(ERROR) << "initial vector size is invalid.";
return -1;
}
if (data == NULL || len <= 0 || encrypt_data == NULL) {
MS_LOG(ERROR) << "input data is invalid.";
return -1; return -1;
} }
if (aesMode == AES_CBC || aesMode == AES_CTR) { if (aesMode == AES_CBC || aesMode == AES_CTR) {
@ -69,12 +72,20 @@ int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned c
int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) { int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) {
int ret = 0; int ret = 0;
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) { if (privKey == NULL || iVec == NULL) {
MS_LOG(ERROR) << "key length must be 16 or 32!"; MS_LOG(ERROR) << "private key or init vector is invalid.";
return -1; return -1;
} }
if (iVecLen != INIT_VEC_SIZE) { if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
MS_LOG(ERROR) << "initial vector size must be 16!"; MS_LOG(ERROR) << "key length is invalid.";
return -1;
}
if (iVecLen != AES_IV_SIZE) {
MS_LOG(ERROR) << "initial vector size is invalid.";
return -1;
}
if (data == NULL || encrypt_len <= 0 || encrypt_data == NULL) {
MS_LOG(ERROR) << "input data is invalid.";
return -1; return -1;
} }
if (aesMode == AES_CBC || aesMode == AES_CTR) { if (aesMode == AES_CBC || aesMode == AES_CTR) {
@ -88,17 +99,21 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
return 0; return 0;
} }
int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const unsigned char *key, unsigned char *ivec, int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
unsigned char *encrypt_data, int *encrypt_len) { uint8_t *encrypt_data, int *encrypt_len) {
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
return -1;
}
int out_len; int out_len;
int ret = 0; int ret;
if (aesMode == AES_CBC) { if (aesMode == AES_CBC) {
switch (privKeyLen) { switch (privKeyLen) {
case 16: case KEY_LENGTH_16:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec); ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
break; break;
case 32: case KEY_LENGTH_32:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec); ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
break; break;
default: default:
@ -107,16 +122,16 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
} }
if (ret != 1) { if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex CBC fail!"; MS_LOG(ERROR) << "EVP_EncryptInit_ex CBC fail!";
EVP_CIPHER_CTX_free(ctx);
return -1; return -1;
} }
EVP_CIPHER_CTX_set_key_length(ctx, EVP_MAX_KEY_LENGTH); EVP_CIPHER_CTX_set_padding(ctx, EVP_PADDING_PKCS7);
EVP_CIPHER_CTX_set_padding(ctx, PAD_SIZE);
} else if (aesMode == AES_CTR) { } else if (aesMode == AES_CTR) {
switch (privKeyLen) { switch (privKeyLen) {
case 16: case KEY_LENGTH_16:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec); ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
break; break;
case 32: case KEY_LENGTH_32:
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec); ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
break; break;
default: default:
@ -125,21 +140,25 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
} }
if (ret != 1) { if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptInit_ex CTR fail!"; MS_LOG(ERROR) << "EVP_EncryptInit_ex CTR fail!";
EVP_CIPHER_CTX_free(ctx);
return -1; return -1;
} }
} else { } else {
MS_LOG(ERROR) << "Unsupported AES mode"; MS_LOG(ERROR) << "Unsupported AES mode";
EVP_CIPHER_CTX_free(ctx);
return -1; return -1;
} }
ret = EVP_EncryptUpdate(ctx, encrypt_data, &out_len, data, len); ret = EVP_EncryptUpdate(ctx, encrypt_data, &out_len, data, len);
if (ret != 1) { if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptUpdate fail!"; MS_LOG(ERROR) << "EVP_EncryptUpdate fail!";
EVP_CIPHER_CTX_free(ctx);
return -1; return -1;
} }
*encrypt_len = out_len; *encrypt_len = out_len;
ret = EVP_EncryptFinal_ex(ctx, encrypt_data + *encrypt_len, &out_len); ret = EVP_EncryptFinal_ex(ctx, encrypt_data + *encrypt_len, &out_len);
if (ret != 1) { if (ret != 1) {
MS_LOG(ERROR) << "EVP_EncryptFinal_ex fail!"; MS_LOG(ERROR) << "EVP_EncryptFinal_ex fail!";
EVP_CIPHER_CTX_free(ctx);
return -1; return -1;
} }
*encrypt_len += out_len; *encrypt_len += out_len;
@ -147,17 +166,21 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
return 0; return 0;
} }
int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len, const unsigned char *key, int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
unsigned char *ivec, unsigned char *decrypt_data, int *decrypt_len) { uint8_t *decrypt_data, int *decrypt_len) {
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
return -1;
}
int out_len; int out_len;
int ret = 0; int ret;
if (aesMode == AES_CBC) { if (aesMode == AES_CBC) {
switch (privKeyLen) { switch (privKeyLen) {
case 16: case KEY_LENGTH_16:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec); ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
break; break;
case 32: case KEY_LENGTH_32:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec); ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
break; break;
default: default:
@ -165,40 +188,46 @@ int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len
ret = -1; ret = -1;
} }
if (ret != 1) { if (ret != 1) {
EVP_CIPHER_CTX_free(ctx);
return -1; return -1;
} }
EVP_CIPHER_CTX_set_key_length(ctx, EVP_MAX_KEY_LENGTH);
} else if (aesMode == AES_CTR) { } else if (aesMode == AES_CTR) {
switch (privKeyLen) { switch (privKeyLen) {
case 16: case KEY_LENGTH_16:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec); ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
break; break;
case 32: case KEY_LENGTH_32:
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec); ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
break; break;
default: default:
MS_LOG(ERROR) << "key length is incorrect!"; MS_LOG(ERROR) << "key length is incorrect!";
ret = -1; ret = -1;
} }
if (ret != 1) {
EVP_CIPHER_CTX_free(ctx);
return -1;
}
} else { } else {
ret = -1; MS_LOG(ERROR) << "Unsupported AES mode";
} EVP_CIPHER_CTX_free(ctx);
if (ret != 1) {
return -1; return -1;
} }
ret = EVP_DecryptUpdate(ctx, decrypt_data, &out_len, encrypt_data, len); ret = EVP_DecryptUpdate(ctx, decrypt_data, &out_len, encrypt_data, len);
if (ret != 1) { if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptUpdate fail!";
EVP_CIPHER_CTX_free(ctx);
return -1; return -1;
} }
*decrypt_len = out_len; *decrypt_len = out_len;
ret = EVP_DecryptFinal_ex(ctx, decrypt_data + *decrypt_len, &out_len); ret = EVP_DecryptFinal_ex(ctx, decrypt_data + *decrypt_len, &out_len);
if (ret != 1) { if (ret != 1) {
MS_LOG(ERROR) << "EVP_DecryptFinal_ex fail!";
EVP_CIPHER_CTX_free(ctx);
return -1; return -1;
} }
*decrypt_len += out_len; *decrypt_len += out_len;
EVP_CIPHER_CTX_free(ctx); EVP_CIPHER_CTX_free(ctx);
return 0; return 0;
} }
#endif #endif

View File

@ -22,7 +22,9 @@
#endif #endif
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#define INIT_VEC_SIZE 16 #define AES_IV_SIZE 16
#define KEY_LENGTH_32 32
#define KEY_LENGTH_16 16
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
@ -35,21 +37,21 @@ class SymmetricEncrypt : Encrypt {};
class AESEncrypt : SymmetricEncrypt { class AESEncrypt : SymmetricEncrypt {
public: public:
AESEncrypt(const unsigned char *key, int key_len, unsigned char *ivec, int ivec_len, AES_MODE mode); AESEncrypt(const unsigned char *key, int key_len, const uint8_t *ivec, int ivec_len, AES_MODE mode);
~AESEncrypt(); ~AESEncrypt();
int EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len); int EncryptData(const uint8_t *data, const int len, uint8_t *encrypt_data, int *encrypt_len);
int DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len); int DecryptData(const uint8_t *encrypt_data, const int encrypt_len, uint8_t *data, int *len);
private: private:
const unsigned char *privKey; const uint8_t *privKey;
int privKeyLen; int privKeyLen;
unsigned char *iVec; const uint8_t *iVec;
int iVecLen; int iVecLen;
AES_MODE aesMode; AES_MODE aesMode;
int evp_aes_encrypt(const unsigned char *data, const int len, const unsigned char *key, unsigned char *ivec, int evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
unsigned char *encrypt_data, int *encrypt_len); uint8_t *encrypt_data, int *encrypt_len);
int evp_aes_decrypt(const unsigned char *encrypt_data, const int len, const unsigned char *key, unsigned char *ivec, int evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
unsigned char *decrypt_data, int *decrypt_len); uint8_t *decrypt_data, int *decrypt_len);
}; };
} // namespace armour } // namespace armour

View File

@ -54,14 +54,22 @@ PrivateKey::PrivateKey(EVP_PKEY *evpKey) { evpPrivKey = evpKey; }
PrivateKey::~PrivateKey() { EVP_PKEY_free(evpPrivKey); } PrivateKey::~PrivateKey() { EVP_PKEY_free(evpPrivKey); }
int PrivateKey::GetPrivateBytes(size_t *len, unsigned char *privKeyBytes) { int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) {
if (privKeyBytes == nullptr || len <= 0) {
MS_LOG(ERROR) << "input privKeyBytes invalid.";
return -1;
}
if (!EVP_PKEY_get_raw_private_key(evpPrivKey, privKeyBytes, len)) { if (!EVP_PKEY_get_raw_private_key(evpPrivKey, privKeyBytes, len)) {
return -1; return -1;
} }
return 0; return 0;
} }
int PrivateKey::GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) { int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) {
if (pubKeyBytes == nullptr || len <= 0) {
MS_LOG(ERROR) << "input pubKeyBytes invalid.";
return -1;
}
if (!EVP_PKEY_get_raw_public_key(evpPrivKey, pubKeyBytes, len)) { if (!EVP_PKEY_get_raw_public_key(evpPrivKey, pubKeyBytes, len)) {
return -1; return -1;
} }
@ -69,7 +77,19 @@ int PrivateKey::GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) {
} }
int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned char *salt, int salt_len, int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned char *salt, int salt_len,
unsigned char *exchangeKey) { uint8_t *exchangeKey) {
if (peerPublicKey == nullptr) {
MS_LOG(ERROR) << "peerPublicKey is nullptr.";
return -1;
}
if (key_len != KEY_LEN || exchangeKey == nullptr) {
MS_LOG(ERROR) << "exchangeKey is nullptr or input key_len is incorrect.";
return -1;
}
if (salt == nullptr || salt_len != SALT_LEN) {
MS_LOG(ERROR) << "input salt in invalid.";
return -1;
}
EVP_PKEY_CTX *ctx; EVP_PKEY_CTX *ctx;
size_t len = 0; size_t len = 0;
ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL); ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL);
@ -79,30 +99,38 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c
} }
if (EVP_PKEY_derive_init(ctx) <= 0) { if (EVP_PKEY_derive_init(ctx) <= 0) {
MS_LOG(ERROR) << "EVP_PKEY_derive_init failed!"; MS_LOG(ERROR) << "EVP_PKEY_derive_init failed!";
EVP_PKEY_CTX_free(ctx);
return -1; return -1;
} }
if (EVP_PKEY_derive_set_peer(ctx, peerPublicKey->evpPubKey) <= 0) { if (EVP_PKEY_derive_set_peer(ctx, peerPublicKey->evpPubKey) <= 0) {
MS_LOG(ERROR) << "EVP_PKEY_derive_set_peer failed!"; MS_LOG(ERROR) << "EVP_PKEY_derive_set_peer failed!";
EVP_PKEY_CTX_free(ctx);
return -1; return -1;
} }
unsigned char *secret; unsigned char *secret;
if (EVP_PKEY_derive(ctx, NULL, &len) <= 0) { if (EVP_PKEY_derive(ctx, NULL, &len) <= 0) {
MS_LOG(ERROR) << "get derive key size failed!"; MS_LOG(ERROR) << "get derive key size failed!";
EVP_PKEY_CTX_free(ctx);
return -1; return -1;
} }
secret = (unsigned char *)OPENSSL_malloc(len); secret = (unsigned char *)OPENSSL_malloc(len);
if (!secret) { if (!secret) {
MS_LOG(ERROR) << "malloc secret memory failed!"; MS_LOG(ERROR) << "malloc secret memory failed!";
EVP_PKEY_CTX_free(ctx);
return -1; return -1;
} }
if (EVP_PKEY_derive(ctx, secret, &len) <= 0) { if (EVP_PKEY_derive(ctx, secret, &len) <= 0) {
MS_LOG(ERROR) << "derive key failed!"; MS_LOG(ERROR) << "derive key failed!";
OPENSSL_free(secret);
EVP_PKEY_CTX_free(ctx);
return -1; return -1;
} }
if (!PKCS5_PBKDF2_HMAC(reinterpret_cast<char *>(secret), len, salt, salt_len, ITERATION, EVP_sha256(), key_len,
if (!PKCS5_PBKDF2_HMAC((char *)secret, len, salt, salt_len, ITERATION, EVP_sha256(), key_len, exchangeKey)) { exchangeKey)) {
OPENSSL_free(secret);
EVP_PKEY_CTX_free(ctx);
return -1; return -1;
} }
OPENSSL_free(secret); OPENSSL_free(secret);
@ -118,9 +146,11 @@ PrivateKey *KeyAgreement::GeneratePrivKey() {
return NULL; return NULL;
} }
if (EVP_PKEY_keygen_init(pctx) <= 0) { if (EVP_PKEY_keygen_init(pctx) <= 0) {
EVP_PKEY_CTX_free(pctx);
return NULL; return NULL;
} }
if (EVP_PKEY_keygen(pctx, &evpKey) <= 0) { if (EVP_PKEY_keygen(pctx, &evpKey) <= 0) {
EVP_PKEY_CTX_free(pctx);
return NULL; return NULL;
} }
EVP_PKEY_CTX_free(pctx); EVP_PKEY_CTX_free(pctx);
@ -131,14 +161,30 @@ PrivateKey *KeyAgreement::GeneratePrivKey() {
PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) { PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
unsigned char *pubKeyBytes; unsigned char *pubKeyBytes;
size_t len = 0; size_t len = 0;
if (privKey == nullptr) {
return NULL;
}
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, NULL, &len)) { if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, NULL, &len)) {
return NULL; return NULL;
} }
pubKeyBytes = (unsigned char *)OPENSSL_malloc(len); pubKeyBytes = reinterpret_cast<uint8_t *>(OPENSSL_malloc(len));
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, pubKeyBytes, &len)) { if (!pubKeyBytes) {
MS_LOG(ERROR) << "malloc secret memory failed!";
return NULL;
}
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, pubKeyBytes, &len)) {
MS_LOG(ERROR) << "EVP_PKEY_get_raw_public_key failed!";
OPENSSL_free(pubKeyBytes);
return NULL;
}
EVP_PKEY *evp_pubKey =
EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, reinterpret_cast<uint8_t *>(pubKeyBytes), len);
if (evp_pubKey == NULL) {
MS_LOG(ERROR) << "EVP_PKEY_new_raw_public_key failed!";
OPENSSL_free(pubKeyBytes);
return NULL; return NULL;
} }
EVP_PKEY *evp_pubKey = EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, (unsigned char *)pubKeyBytes, len);
OPENSSL_free(pubKeyBytes); OPENSSL_free(pubKeyBytes);
PublicKey *pubKey = new PublicKey(evp_pubKey); PublicKey *pubKey = new PublicKey(evp_pubKey);
return pubKey; return pubKey;
@ -147,6 +193,7 @@ PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) { PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) {
EVP_PKEY *evp_Key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, NULL, data, len); EVP_PKEY *evp_Key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, NULL, data, len);
if (evp_Key == NULL) { if (evp_Key == NULL) {
MS_LOG(ERROR) << "create evp_Key from raw bytes failed!";
return NULL; return NULL;
} }
PrivateKey *privKey = new PrivateKey(evp_Key); PrivateKey *privKey = new PrivateKey(evp_Key);
@ -164,7 +211,11 @@ PublicKey *KeyAgreement::FromPublicBytes(unsigned char *data, int len) {
} }
int KeyAgreement::ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey, int key_len, int KeyAgreement::ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey, int key_len,
const unsigned char *salt, int salt_len, unsigned char *exchangeKey) { const unsigned char *salt, int salt_len, uint8_t *exchangeKey) {
if (privKey == nullptr) {
MS_LOG(ERROR) << "privKey is nullptr!";
return -1;
}
return privKey->Exchange(peerPublicKey, key_len, salt, salt_len, exchangeKey); return privKey->Exchange(peerPublicKey, key_len, salt, salt_len, exchangeKey);
} }
#endif #endif

View File

@ -24,7 +24,8 @@
#endif #endif
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#define KEK_KEY_LEN 32 #define KEY_LEN 32
#define SALT_LEN 32
#define ITERATION 10000 #define ITERATION 10000
namespace mindspore { namespace mindspore {

View File

@ -14,44 +14,39 @@
* limitations under the License. * limitations under the License.
*/ */
#include "fl/armour/secure_protocol/random.h" #include "fl/armour/secure_protocol/masking.h"
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
Random::Random(size_t init_seed) { generator.seed(init_seed); }
Random::~Random() {}
#ifdef _WIN32 #ifdef _WIN32
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) { int Masking::GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *seed, int seed_len,
MS_LOG(ERROR) << "Unsupported feature in Windows platform."; const uint8_t *ivec, int ivec_size) {
return -1;
}
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
MS_LOG(ERROR) << "Unsupported feature in Windows platform."; MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
return -1; return -1;
} }
#else #else
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) { int Masking::GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *secret, int secret_len,
int retval = RAND_priv_bytes(secret, num_bytes); const uint8_t *ivec, int ivec_size) {
return retval; if ((secret_len != KEY_LENGTH_16 && secret_len != KEY_LENGTH_32) || secret == NULL) {
} MS_LOG(ERROR) << "secret is invalid!";
return -1;
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) { }
if (seed_len != 16 && seed_len != 32) { if (noise == NULL || noise_len <= 0) {
MS_LOG(ERROR) << "seed length must be 16 or 32!"; MS_LOG(ERROR) << "noise is invalid!";
return -1;
}
if (ivec == NULL || ivec_size != AES_IV_SIZE) {
MS_LOG(ERROR) << "ivec is invalid!";
return -1; return -1;
} }
int size = noise_len * sizeof(int); int size = noise_len * sizeof(int);
std::vector<unsigned char> data(size, 0); std::vector<uint8_t> data(size, 0);
std::vector<unsigned char> encrypt_data(size, 0); std::vector<uint8_t> encrypt_data(size, 0);
std::vector<unsigned char> ivec(INIT_VEC_SIZE, 0);
int encrypt_len = 0; int encrypt_len = 0;
AESEncrypt encrypt(seed, seed_len, ivec.data(), INIT_VEC_SIZE, AES_CTR); AESEncrypt encrypt(secret, secret_len, ivec, AES_IV_SIZE, AES_CTR);
if (encrypt.EncryptData(data.data(), size, encrypt_data.data(), &encrypt_len) != 0) { if (encrypt.EncryptData(data.data(), size, encrypt_data.data(), &encrypt_len) != 0) {
MS_LOG(ERROR) << "call encryptData fail!"; MS_LOG(ERROR) << "call AES-CTR failed!";
return -1; return -1;
} }

View File

@ -19,27 +19,15 @@
#include <random> #include <random>
#include <vector> #include <vector>
#ifndef _WIN32
#include <openssl/rand.h>
#endif
#include "fl/armour/secure_protocol/encrypt.h" #include "fl/armour/secure_protocol/encrypt.h"
namespace mindspore { namespace mindspore {
namespace armour { namespace armour {
#define RANDOM_LEN 8 class Masking {
class Random {
public: public:
explicit Random(size_t init_seed); static int GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *secret, int secret_len,
~Random(); const uint8_t *ivec, int ivec_size);
// use openssl RAND_priv_bytes
static int GetRandomBytes(unsigned char *secret, int num_bytes);
static int RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len);
private:
std::default_random_engine generator;
}; };
} // namespace armour } // namespace armour
} // namespace mindspore } // namespace mindspore

View File

@ -62,6 +62,11 @@ struct RoundConfig {
struct CipherConfig { struct CipherConfig {
float share_secrets_ratio = 1.0; float share_secrets_ratio = 1.0;
uint64_t cipher_time_window = 300000; uint64_t cipher_time_window = 300000;
size_t exchange_keys_threshold = 0;
size_t get_keys_threshold = 0;
size_t share_secrets_threshold = 0;
size_t get_secrets_threshold = 0;
size_t client_list_threshold = 0;
size_t reconstruct_secrets_threshold = 0; size_t reconstruct_secrets_threshold = 0;
}; };
@ -207,8 +212,11 @@ constexpr auto kCtxClientNoises = "clients_noises";
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares"; constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";
constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares"; constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares";
constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list"; constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list";
constexpr auto kCtxGetSecretsClientList = "get_secrets_client_list";
constexpr auto kCtxReconstructClientList = "reconstruct_client_list"; constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list"; constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list";
constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size"; constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
constexpr auto kCtxCipherPrimer = "cipher_primer"; constexpr auto kCtxCipherPrimer = "cipher_primer";

View File

@ -41,108 +41,102 @@ void ClientListKernel::InitKernel(size_t) {
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req, bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
std::shared_ptr<server::FBBuilder> fbb) { std::shared_ptr<server::FBBuilder> fbb) {
bool response = false;
std::vector<string> client_list; std::vector<string> client_list;
std::vector<string> empty_client_list;
std::string fl_id = get_clients_req->fl_id()->str(); std::string fl_id = get_clients_req->fl_id()->str();
int32_t iter_client = (size_t)get_clients_req->iteration();
if (iter_num != (size_t)iter_client) { if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
MS_LOG(ERROR) << "ClientListKernel iteration invalid. servertime is " << iter_num; MS_LOG(ERROR) << "update_model_client_threshold is not set.";
MS_LOG(ERROR) << "ClientListKernel iteration invalid. clienttime is " << iter_client; BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_threshold is not set.",
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list, empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
std::to_string(CURRENT_TIME_MILLI.count()), iter_num); return false;
} else {
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
uint64_t update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
}
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) { // client in client_list.
if (static_cast<uint64_t>(client_list_pb.fl_id_size()) >= update_model_client_num) {
MS_LOG(INFO) << "send clients_list succeed!";
MS_LOG(INFO) << "UpdateModel client list: ";
for (size_t i = 0; i < client_list.size(); ++i) {
MS_LOG(INFO) << " fl_id : " << client_list[i];
}
MS_LOG(INFO) << "update_model_client_num: " << update_model_client_num;
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
response = true;
} else {
MS_LOG(INFO) << "The server is not ready. update_model_client_need_num: " << update_model_client_num;
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
/*for (size_t i = 0; i < std::min(client_list.size(), size_t(2)); ++i) {
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
}
for (size_t i = client_list.size() - size_t(1); i > std::max(client_list.size() - size_t(2), size_t(0));
--i) {
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
}*/
int count_tmp = 0;
for (size_t i = 0; i < cipher_init_->get_model_num_need_; ++i) {
size_t j = 0;
for (; j < client_list.size(); ++j) {
if (("f" + std::to_string(i)) == client_list[j]) break;
}
if (j >= client_list.size()) {
count_tmp++;
MS_LOG(INFO) << " no client_list fl_id : " << i;
if (count_tmp > 3) break;
}
}
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
}
}
if (response) {
DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str());
}
} else {
MS_LOG(ERROR) << "update_model_client_num is zero.";
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_num is zero.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
}
} }
return response; uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
}
if (static_cast<uint64_t>(client_list.size()) < update_model_client_needed) {
MS_LOG(INFO) << "The server is not ready. update_model_client_needed: " << update_model_client_needed;
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false;
}
if (find(client_list.begin(), client_list.end(), fl_id) == client_list.end()) { // client not in update model clients
std::string reason = "fl_id: " + fl_id + " is not in the update_model_clients";
MS_LOG(INFO) << reason;
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false;
}
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetUpdateModelClientList, fl_id);
if (!retcode_client) {
std::string reason = "update get update model clients failed";
MS_LOG(ERROR) << reason;
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false;
}
if (!DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str())) {
std::string reason = "Counting for get user list request failed. Please retry later.";
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
return true;
}
MS_LOG(INFO) << "send clients_list succeed!";
MS_LOG(INFO) << "UpdateModel client list: ";
for (size_t i = 0; i < client_list.size(); ++i) {
MS_LOG(INFO) << " fl_id : " << client_list[i];
}
MS_LOG(INFO) << "update_model_client_needed: " << update_model_client_needed;
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return true;
} }
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration; MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
clock_t start_time = clock(); clock_t start_time = clock();
std::vector<string> client_list; if (inputs.size() != 1 || outputs.size() != 1) {
if (inputs.size() != 1) { std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << "ClientListKernel needs 1 input,but got " << inputs.size(); MS_LOG(ERROR) << reason;
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel input num not match", client_list, return false;
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ClientListKernel needs 1 output,but got " << outputs.size();
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel output num not match", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "ClientListKernel num is enough", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
void *req_data = inputs[0]->addr;
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
if (get_clients_req == nullptr || fbb == nullptr) {
MS_LOG(ERROR) << "GetClientList is nullptr or ClientListRsp builder is nullptr.";
BuildClientListRsp(fbb, schema::ResponseCode_RequestError,
"GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
DealClient(iter_num, get_clients_req, fbb);
}
}
} }
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
std::vector<string> client_list;
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
int32_t iter_client = (size_t)get_clients_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "client list iteration number is invalid: server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
}
DealClient(iter_num, get_clients_req, fbb);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock(); clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
@ -164,30 +158,22 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli
const int iteration) { const int iteration) {
auto rsp_reason = client_list_resp_builder->CreateString(reason); auto rsp_reason = client_list_resp_builder->CreateString(reason);
auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time); auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time);
if (clients.size() > 0) { std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector; for (auto client : clients) {
for (auto client : clients) { auto client_fb = client_list_resp_builder->CreateString(client);
auto client_fb = client_list_resp_builder->CreateString(client); clients_vector.push_back(client_fb);
clients_vector.push_back(client_fb); MS_LOG(WARNING) << "update client list: ";
} MS_LOG(WARNING) << client;
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_clients(clients_fb);
rsp_builder.add_iteration(iteration);
rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys);
} else {
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_iteration(iteration);
rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys);
} }
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_clients(clients_fb);
rsp_builder.add_iteration(iteration);
rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys);
return; return;
} }

View File

@ -38,9 +38,36 @@ void ExchangeKeysKernel::InitKernel(size_t) {
cipher_key_ = &armour::CipherKeys::GetInstance(); cipher_key_ = &armour::CipherKeys::GetInstance();
} }
bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for exchangeKey is enough. Please retry later.";
cipher_key_->BuildExchangeKeysRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
MS_LOG(WARNING) << reason;
return true;
}
return false;
}
bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestExchangeKeys *exchange_keys_req,
const int iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(exchange_keys_req, false);
if (!DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str())) {
std::string reason = "Counting for exchange kernel request failed. Please retry later.";
cipher_key_->BuildExchangeKeysRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>(); MS_LOG(INFO) << "Launching ExchangeKey kernel.";
bool response = false; bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
@ -48,46 +75,49 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
<< total_duration; << total_duration;
clock_t start_time = clock(); clock_t start_time = clock();
if (inputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 input,but got " << inputs.size(); std::string reason = "inputs or outputs size is invalid.";
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel input num not match", MS_LOG(ERROR) << reason;
return false;
}
void *req_data = inputs[0]->addr;
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
if (ReachThresholdForExchangeKeys(fbb, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
int32_t iter_client = (size_t)exchange_keys_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "ExchangeKeys iteration number is invalid: server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num); std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else if (outputs.size() != 1) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 output,but got " << outputs.size(); return true;
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel output num not match", }
std::to_string(CURRENT_TIME_MILLI.count()), iter_num); response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
} else { if (!response) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { MS_LOG(WARNING) << "update exchange keys is failed.";
MS_LOG(ERROR) << "Current amount for ExchangeKeysKernel is enough."; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, return true;
"Current amount for ExchangeKeysKernel is enough.", }
std::to_string(CURRENT_TIME_MILLI.count()), iter_num); if (!CountForExchangeKeys(fbb, exchange_keys_req, iter_num)) {
} else { MS_LOG(ERROR) << "count for exchange keys failed.";
void *req_data = inputs[0]->addr; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
const schema::RequestExchangeKeys *exchange_keys_req = return true;
flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
int32_t iter_client = (size_t)exchange_keys_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "ExchangeKeysKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
response =
cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
if (response) {
DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str());
}
}
}
} }
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock(); clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration; MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
if (!response) {
MS_LOG(INFO) << "ExchangeKeysKernel response is false.";
}
return true; return true;
} }

View File

@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#include <vector> #include <vector>
#include <string>
#include <memory>
#include "fl/server/common.h" #include "fl/server/common.h"
#include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/kernel/round/round_kernel_factory.h"
@ -41,6 +43,9 @@ class ExchangeKeysKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
size_t iteration_time_window_; size_t iteration_time_window_;
armour::CipherKeys *cipher_key_; armour::CipherKeys *cipher_key_;
bool ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num);
bool CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestExchangeKeys *exchange_keys_req,
const int iter_num);
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -37,9 +37,23 @@ void GetKeysKernel::InitKernel(size_t) {
cipher_key_ = &armour::CipherKeys::GetInstance(); cipher_key_ = &armour::CipherKeys::GetInstance();
} }
bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
const int iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(get_keys_req, false);
if (!DistributedCountService::GetInstance().Count(name_, get_keys_req->fl_id()->str())) {
std::string reason = "Counting for getkeys kernel request failed. Please retry later.";
cipher_key_->BuildGetKeysRsp(
fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), false);
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>(); MS_LOG(INFO) << "Launching GetKeys kernel.";
bool response = false; bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
@ -47,44 +61,48 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
<< total_duration; << total_duration;
clock_t start_time = clock(); clock_t start_time = clock();
if (inputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "GetKeysKernel needs 1 input,but got " << inputs.size(); std::string reason = "inputs or outputs size is invalid.";
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num, MS_LOG(ERROR) << reason;
std::to_string(CURRENT_TIME_MILLI.count()), false); return false;
} else if (outputs.size() != 1) { }
MS_LOG(ERROR) << "GetKeysKernel needs 1 output,but got " << outputs.size();
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num, std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
std::to_string(CURRENT_TIME_MILLI.count()), false); void *req_data = inputs[0]->addr;
} else { if (fbb == nullptr || req_data == nullptr) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough."; MS_LOG(ERROR) << reason;
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num, return false;
std::to_string(CURRENT_TIME_MILLI.count()), false); }
} else {
void *req_data = inputs[0]->addr; if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data); MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
int32_t iter_client = (size_t)get_exchange_keys_req->iteration(); }
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
<< ". client request iteration is " << iter_client; int32_t iter_client = (size_t)get_exchange_keys_req->iteration();
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num, if (iter_num != (size_t)iter_client) {
std::to_string(CURRENT_TIME_MILLI.count()), false); MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
} else { << ". client request iteration is " << iter_client;
response = cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb); std::to_string(CURRENT_TIME_MILLI.count()), false);
if (response) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
DistributedCountService::GetInstance().Count(name_, get_exchange_keys_req->fl_id()->str()); return true;
} }
} response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
} if (!response) {
MS_LOG(WARNING) << "get public keys is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForGetKeys(fbb, get_exchange_keys_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
} }
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
clock_t end_time = clock(); clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration; MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
if (!response) {
MS_LOG(INFO) << "GetKeysKernel response is false.";
}
return true; return true;
} }

View File

@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
#include <vector> #include <vector>
#include <string>
#include <memory>
#include "fl/server/common.h" #include "fl/server/common.h"
#include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/kernel/round/round_kernel_factory.h"
@ -41,6 +43,8 @@ class GetKeysKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
size_t iteration_time_window_; size_t iteration_time_window_;
armour::CipherKeys *cipher_key_; armour::CipherKeys *cipher_key_;
bool CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
const int iter_num);
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -39,54 +39,72 @@ void GetSecretsKernel::InitKernel(size_t) {
cipher_share_ = &armour::CipherShares::GetInstance(); cipher_share_ = &armour::CipherShares::GetInstance();
} }
bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
const schema::GetShareSecrets *get_secrets_req, const int iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false);
if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) {
std::string reason = "Counting for get secrets kernel request failed. Please retry later.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), nullptr);
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
bool response = false; bool response = false;
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count()); std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is " MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
<< total_duration; << total_duration;
clock_t start_time = clock(); clock_t start_time = clock();
if (inputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "GetSecretsKernel needs 1 input,but got " << inputs.size(); std::string reason = "inputs or outputs size is invalid.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0); MS_LOG(ERROR) << reason;
} else if (outputs.size() != 1) { return false;
MS_LOG(ERROR) << "GetSecretsKernel needs 1 output,but got " << outputs.size();
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
} else {
void *req_data = inputs[0]->addr;
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
int32_t iter_client = (size_t)get_secrets_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
} else {
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
if (response) {
DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str());
}
}
}
} }
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
int32_t iter_client = (size_t)get_secrets_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
}
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
if (!response) {
MS_LOG(WARNING) << "get secret shares is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForGetSecrets(fbb, get_secrets_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock(); clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration; MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
if (!response) {
MS_LOG(INFO) << "GetSecretsKernel response is false.";
}
return true; return true;
} }

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#include <vector> #include <vector>
#include <memory>
#include "fl/server/common.h" #include "fl/server/common.h"
#include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/kernel/round/round_kernel_factory.h"
@ -41,6 +42,8 @@ class GetSecretsKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
size_t iteration_time_window_; size_t iteration_time_window_;
armour::CipherShares *cipher_share_; armour::CipherShares *cipher_share_;
bool CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::GetShareSecrets *get_secrets_req,
const int iter_num);
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -52,7 +52,6 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
bool response = false; bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
// MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); // MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
@ -61,71 +60,58 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
<< total_duration; << total_duration;
clock_t start_time = clock(); clock_t start_time = clock();
if (inputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size(); MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError, return false;
"ReconstructSecretsKernel input num not match.", iter_num, }
std::to_string(CURRENT_TIME_MILLI.count()));
} else if (outputs.size() != 1) { std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 output, but got " << outputs.size(); void *req_data = inputs[0]->addr;
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
"ReconstructSecretsKernel output num not match.", iter_num, if (fbb == nullptr || req_data == nullptr) {
std::to_string(CURRENT_TIME_MILLI.count())); std::string reason = "FBBuilder builder or req_data is nullptr.";
} else { MS_LOG(ERROR) << reason;
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { return false;
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough."; }
// get client list from memory server.
std::vector<string> update_model_clients;
const PBMetadata update_model_clients_pb_out =
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
for (int i = 0; i < update_model_clients_pb.fl_id_size(); ++i) {
update_model_clients.push_back(update_model_clients_pb.fl_id(i));
}
const schema::SendReconstructSecret *reconstruct_secret_req =
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
std::string fl_id = reconstruct_secret_req->fl_id()->str();
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
// client in get update model client list.
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED,
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
} else {
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ReconstructSecretsKernel is enough.", iter_num, "Current amount for ReconstructSecretsKernel is enough.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count())); std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
void *req_data = inputs[0]->addr;
const schema::SendReconstructSecret *reconstruct_secret_req =
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
// get client list from memory server.
std::vector<string> client_list;
uint64_t update_model_client_num = 0;
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
} else {
MS_LOG(ERROR) << "update_model_client_num is zero.";
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
"update_model_client_num is zero.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const PBMetadata client_list_pb_out =
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
int client_list_actual_size = client_list_pb.fl_id_size();
if (client_list_actual_size < 0) {
client_list_actual_size = 0;
}
if (static_cast<uint64_t>(client_list_actual_size) < update_model_client_num) {
MS_LOG(INFO) << "ReconstructSecretsKernel : client list is not ready " << inputs.size();
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SucNotReady,
"ReconstructSecretsKernel : client list is not ready", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
}
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
reconstruct_secret_req, fbb, client_list);
if (response) {
// MS_LOG(INFO) << "start ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
// MS_LOG(INFO) << "end ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
} }
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
} }
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
reconstruct_secret_req, fbb, update_model_clients);
if (response) {
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
}
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough.";
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock(); clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
@ -142,7 +128,6 @@ void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::
while (!Executor::GetInstance().IsAllWeightAggregationDone()) { while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5)); std::this_thread::sleep_for(std::chrono::milliseconds(5));
} }
MS_LOG(INFO) << "start unmask"; MS_LOG(INFO) << "start unmask";
while (!Executor::GetInstance().Unmask()) { while (!Executor::GetInstance().Unmask()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5)); std::this_thread::sleep_for(std::chrono::milliseconds(5));

View File

@ -36,58 +36,75 @@ void ShareSecretsKernel::InitKernel(size_t) {
cipher_share_ = &armour::CipherShares::GetInstance(); cipher_share_ = &armour::CipherShares::GetInstance();
} }
bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestShareSecrets *share_secrets_req,
const int iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(share_secrets_req, false);
if (!DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str())) {
std::string reason = "Counting for share secret kernel request failed. Please retry later.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
return false;
}
return true;
}
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
bool response = false; bool response = false;
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is " MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is "
<< total_duration; << total_duration;
clock_t start_time = clock(); clock_t start_time = clock();
if (inputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 input,but got " << inputs.size(); std::string reason = "inputs or outputs size is invalid.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError, "ShareSecretsKernel input num not match", MS_LOG(ERROR) << reason;
std::to_string(CURRENT_TIME_MILLI.count()), iter_num); return false;
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 output,but got " << outputs.size();
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError,
"ShareSecretsKernel output num not match",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ShareSecretsKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
void *req_data = inputs[0]->addr;
const schema::RequestShareSecrets *share_secrets_req =
flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
size_t iter_client = (size_t)share_secrets_req->iteration();
if (iter_num != iter_client) {
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
response =
cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
if (response) {
DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str());
}
}
}
} }
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ShareSecretsKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
size_t iter_client = (size_t)share_secrets_req->iteration();
if (iter_num != iter_client) {
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
response = cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
if (!response) {
MS_LOG(WARNING) << "update secret shares is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!CountForShareSecrets(fbb, share_secrets_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock(); clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration; MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
if (!response) {
MS_LOG(INFO) << "share_secrets_kernel response is false.";
}
return true; return true;
} }

View File

@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H #define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#include <vector> #include <vector>
#include <string>
#include <memory>
#include "fl/server/common.h" #include "fl/server/common.h"
#include "fl/server/executor.h" #include "fl/server/executor.h"
#include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel.h"
@ -41,6 +43,8 @@ class ShareSecretsKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
size_t iteration_time_window_; size_t iteration_time_window_;
armour::CipherShares *cipher_share_; armour::CipherShares *cipher_share_;
bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req,
const int iter_num);
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -157,14 +157,31 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas); PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas(); FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
std::string update_model_fl_id = update_model_req->fl_id()->str(); std::string update_model_fl_id = update_model_req->fl_id()->str();
MS_LOG(INFO) << "Update model for fl id " << update_model_fl_id; MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id;
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) { if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later."; if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
BuildUpdateModelRsp( std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
fbb, schema::ResponseCode_OutOfTime, reason, BuildUpdateModelRsp(
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp))); fbb, schema::ResponseCode_OutOfTime, reason,
MS_LOG(ERROR) << reason; std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
return ResultCode::kSuccessAndReturn; MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
}
} else {
std::vector<std::string> get_secrets_clients;
#ifdef ENABLE_ARMOUR
mindspore::armour::CipherMetaStorage cipher_meta_storage;
cipher_meta_storage.GetClientListFromServer(fl::server::kCtxGetSecretsClientList, &get_secrets_clients);
#endif
if (find(get_secrets_clients.begin(), get_secrets_clients.end(), update_model_fl_id) ==
get_secrets_clients.end()) { // the client not in get_secrets_clients
std::string reason = "fl_id: " + update_model_fl_id + " is not in get_secrets_clients. Please retry later.";
BuildUpdateModelRsp(
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return ResultCode::kSuccessAndReturn;
}
} }
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size(); size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();

View File

@ -25,6 +25,9 @@
#include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/kernel/round/round_kernel_factory.h"
#include "fl/server/executor.h" #include "fl/server/executor.h"
#ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_meta_storage.h"
#endif
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {

View File

@ -213,53 +213,24 @@ void Server::InitIteration() {
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) { if (encrypt_type == ps::kPWEncryptType) {
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count; cipher_exchange_keys_cnt_ = cipher_config_.exchange_keys_threshold;
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0; cipher_get_keys_cnt_ = cipher_config_.get_keys_threshold;
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio; cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count; cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count; cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold; cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
cipher_time_window_ = cipher_config_.cipher_time_window; cipher_time_window_ = cipher_config_.cipher_time_window;
MS_LOG(INFO) << "Initializing cipher:"; MS_LOG(INFO) << "Initializing cipher:";
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_ MS_LOG(INFO) << " cipher_exchange_keys_cnt_: " << cipher_exchange_keys_cnt_
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_ << " cipher_get_keys_cnt_: " << cipher_get_keys_cnt_
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_; << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_ MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
<< " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_ << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
<< " cipher_time_window_: " << cipher_time_window_ << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_; << " cipher_time_window_: " << cipher_time_window_;
std::shared_ptr<Round> exchange_keys_round =
std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
MS_EXCEPTION_IF_NULL(exchange_keys_round);
iteration_->AddRound(exchange_keys_round);
std::shared_ptr<Round> get_keys_round =
std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
MS_EXCEPTION_IF_NULL(get_keys_round);
iteration_->AddRound(get_keys_round);
std::shared_ptr<Round> share_secrets_round =
std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
MS_EXCEPTION_IF_NULL(share_secrets_round);
iteration_->AddRound(share_secrets_round);
std::shared_ptr<Round> get_secrets_round =
std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
MS_EXCEPTION_IF_NULL(get_secrets_round);
iteration_->AddRound(get_secrets_round);
std::shared_ptr<Round> get_clientlist_round =
std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
MS_EXCEPTION_IF_NULL(get_clientlist_round);
iteration_->AddRound(get_clientlist_round);
std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
"reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
MS_EXCEPTION_IF_NULL(reconstruct_secrets_round);
iteration_->AddRound(reconstruct_secrets_round);
MS_LOG(INFO) << "Cipher rounds has been added.";
} }
#endif #endif
@ -314,8 +285,8 @@ void Server::InitCipher() {
param.dp_eps = dp_eps; param.dp_eps = dp_eps;
param.dp_norm_clip = dp_norm_clip; param.dp_norm_clip = dp_norm_clip;
param.encrypt_type = encrypt_type; param.encrypt_type = encrypt_type;
cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_, cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_, cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
cipher_reconstruct_secrets_up_cnt_); cipher_reconstruct_secrets_up_cnt_);
#endif #endif
} }

View File

@ -80,7 +80,7 @@ class Server {
worker_num_(0), worker_num_(0),
fl_server_port_(0), fl_server_port_(0),
cipher_initial_client_cnt_(0), cipher_initial_client_cnt_(0),
cipher_exchange_secrets_cnt_(0), cipher_exchange_keys_cnt_(0),
cipher_share_secrets_cnt_(0), cipher_share_secrets_cnt_(0),
cipher_get_clientlist_cnt_(0), cipher_get_clientlist_cnt_(0),
cipher_reconstruct_secrets_up_cnt_(0), cipher_reconstruct_secrets_up_cnt_(0),
@ -197,8 +197,10 @@ class Server {
uint32_t worker_num_; uint32_t worker_num_;
uint16_t fl_server_port_; uint16_t fl_server_port_;
size_t cipher_initial_client_cnt_; size_t cipher_initial_client_cnt_;
size_t cipher_exchange_secrets_cnt_; size_t cipher_exchange_keys_cnt_;
size_t cipher_get_keys_cnt_;
size_t cipher_share_secrets_cnt_; size_t cipher_share_secrets_cnt_;
size_t cipher_get_secrets_cnt_;
size_t cipher_get_clientlist_cnt_; size_t cipher_get_clientlist_cnt_;
size_t cipher_reconstruct_secrets_up_cnt_; size_t cipher_reconstruct_secrets_up_cnt_;
size_t cipher_reconstruct_secrets_down_cnt_; size_t cipher_reconstruct_secrets_down_cnt_;

View File

@ -856,9 +856,33 @@ bool StartServerAction(const ResourcePtr &res) {
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio(); float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window(); uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold(); size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold() + 1;
fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshold}; size_t exchange_keys_threshold =
std::max(static_cast<size_t>(std::ceil(start_fl_job_threshold * share_secrets_ratio)), update_model_threshold);
size_t get_keys_threshold =
std::max(static_cast<size_t>(std::ceil(exchange_keys_threshold * share_secrets_ratio)), update_model_threshold);
size_t share_secrets_threshold =
std::max(static_cast<size_t>(std::ceil(get_keys_threshold * share_secrets_ratio)), update_model_threshold);
size_t get_secrets_threshold =
std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold);
size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)),
reconstruct_secrets_threshold);
#ifdef ENABLE_ARMOUR
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) {
MS_LOG(INFO) << "Add secure aggregation rounds.";
rounds_config.push_back({"exchangeKeys", true, cipher_time_window, true, exchange_keys_threshold});
rounds_config.push_back({"getKeys", true, cipher_time_window, true, get_keys_threshold});
rounds_config.push_back({"shareSecrets", true, cipher_time_window, true, share_secrets_threshold});
rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold});
rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
}
#endif
fl::server::CipherConfig cipher_config = {
share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,
share_secrets_threshold, get_secrets_threshold, client_list_threshold, reconstruct_secrets_threshold};
size_t executor_threshold = 0; size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {

View File

@ -126,6 +126,13 @@ message SharesPb {
message KeysPb { message KeysPb {
repeated bytes key = 1; repeated bytes key = 1;
string timestamp = 2;
int32 iter_num = 3;
bytes ind_iv = 4;
bytes pw_iv = 5;
bytes pw_salt = 6;
bytes signature = 7;
repeated string certificate_chain = 8;
} }
message Prime { message Prime {

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,12 +16,18 @@
package com.mindspore.flclient; package com.mindspore.flclient;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.I_VEC_LEN;
import static com.mindspore.flclient.LocalFLParameter.SALT_SIZE;
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.cipher.AESEncrypt; import com.mindspore.flclient.cipher.AESEncrypt;
import com.mindspore.flclient.cipher.BaseUtil; import com.mindspore.flclient.cipher.BaseUtil;
import com.mindspore.flclient.cipher.ClientListReq; import com.mindspore.flclient.cipher.ClientListReq;
import com.mindspore.flclient.cipher.KEYAgreement; import com.mindspore.flclient.cipher.KEYAgreement;
import com.mindspore.flclient.cipher.Random; import com.mindspore.flclient.cipher.Masking;
import com.mindspore.flclient.cipher.ReconstructSecretReq; import com.mindspore.flclient.cipher.ReconstructSecretReq;
import com.mindspore.flclient.cipher.ShareSecrets; import com.mindspore.flclient.cipher.ShareSecrets;
import com.mindspore.flclient.cipher.struct.ClientPublicKey; import com.mindspore.flclient.cipher.struct.ClientPublicKey;
@ -29,6 +35,7 @@ import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import com.mindspore.flclient.cipher.struct.EncryptShare; import com.mindspore.flclient.cipher.struct.EncryptShare;
import com.mindspore.flclient.cipher.struct.NewArray; import com.mindspore.flclient.cipher.struct.NewArray;
import com.mindspore.flclient.cipher.struct.ShareSecret; import com.mindspore.flclient.cipher.struct.ShareSecret;
import mindspore.schema.ClientShare; import mindspore.schema.ClientShare;
import mindspore.schema.GetExchangeKeys; import mindspore.schema.GetExchangeKeys;
import mindspore.schema.GetShareSecrets; import mindspore.schema.GetShareSecrets;
@ -40,22 +47,22 @@ import mindspore.schema.ResponseShareSecrets;
import mindspore.schema.ReturnExchangeKeys; import mindspore.schema.ReturnExchangeKeys;
import mindspore.schema.ReturnShareSecrets; import mindspore.schema.ReturnShareSecrets;
import java.io.UnsupportedEncodingException; import java.io.IOException;
import java.math.BigInteger; import java.math.BigInteger;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException; import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.time.LocalDateTime;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME; /**
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN; * A class used for secure aggregation
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE; *
* @since 2021-8-27
*/
public class CipherClient { public class CipherClient {
private static final Logger LOGGER = Logger.getLogger(CipherClient.class.toString()); private static final Logger LOGGER = Logger.getLogger(CipherClient.class.toString());
private FLCommunication flCommunication; private FLCommunication flCommunication;
@ -63,129 +70,217 @@ public class CipherClient {
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private final int iteration; private final int iteration;
private int featureSize; private int featureSize;
private int t; private int minShareNum;
private List<byte[]> cKey = new ArrayList<>(); private List<byte[]> cKey = new ArrayList<>();
private List<byte[]> sKey = new ArrayList<>(); private List<byte[]> sKey = new ArrayList<>();
private byte[] bu; private byte[] bu;
private byte[] individualIv = new byte[I_VEC_LEN];
private byte[] pwIVec = new byte[I_VEC_LEN];
private byte[] pwSalt = new byte[SALT_SIZE];
private String nextRequestTime; private String nextRequestTime;
private Map<String, ClientPublicKey> clientPublicKeyList = new HashMap<String, ClientPublicKey>(); private Map<String, ClientPublicKey> clientPublicKeyList = new HashMap<String, ClientPublicKey>();
private Map<String, byte[]> sUVKeys = new HashMap<String, byte[]>(); private Map<String, byte[]> sUVKeys = new HashMap<String, byte[]>();
private Map<String, byte[]> cUVKeys = new HashMap<String, byte[]>(); private Map<String, byte[]> cUVKeys = new HashMap<String, byte[]>();
private List<EncryptShare> clientShareList = new ArrayList<>(); private List<EncryptShare> clientShareList = new ArrayList<>();
private List<EncryptShare> returnShareList = new ArrayList<>(); private List<EncryptShare> returnShareList = new ArrayList<>();
private float[] featureMask;
private List<String> u1ClientList = new ArrayList<>(); private List<String> u1ClientList = new ArrayList<>();
private List<String> u2UClientList = new ArrayList<>(); private List<String> u2UClientList = new ArrayList<>();
private List<String> u3ClientList = new ArrayList<>(); private List<String> u3ClientList = new ArrayList<>();
private List<DecryptShareSecrets> decryptShareSecretsList = new ArrayList<>(); private List<DecryptShareSecrets> decryptShareSecretsList = new ArrayList<>();
private byte[] prime; private byte[] prime;
private KEYAgreement keyAgreement = new KEYAgreement(); private KEYAgreement keyAgreement = new KEYAgreement();
private Random random = new Random(); private Masking masking = new Masking();
private ClientListReq clientListReq = new ClientListReq(); private ClientListReq clientListReq = new ClientListReq();
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq(); private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
private int retCode; private int retCode;
/**
* Construct function of cipherClient
*
* @param iter iteration number
* @param minSecretNum minimum secret shares number used for reconstruct secret
* @param prime prime value
* @param featureSize featureSize of network
*/
public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) { public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) {
flCommunication = FLCommunication.getInstance(); flCommunication = FLCommunication.getInstance();
this.iteration = iter; this.iteration = iter;
this.featureSize = featureSize; this.featureSize = featureSize;
this.t = minSecretNum; this.minShareNum = minSecretNum;
this.prime = prime; this.prime = prime;
this.featureMask = new float[this.featureSize];
} }
/**
* Set next request time
*
* @param nextRequestTime next request timestamp
*/
public void setNextRequestTime(String nextRequestTime) { public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime; this.nextRequestTime = nextRequestTime;
} }
public void setBU(byte[] bu) { /**
this.bu = bu; * Set client share list
} *
* @param clientShareList client share list
public void setClientShareList(List<EncryptShare> clientShareList) { */
private void setClientShareList(List<EncryptShare> clientShareList) {
this.clientShareList.clear(); this.clientShareList.clear();
this.clientShareList = clientShareList; this.clientShareList = clientShareList;
} }
/**
* get next request time
*
* @return next request time
*/
public String getNextRequestTime() { public String getNextRequestTime() {
return nextRequestTime; return nextRequestTime;
} }
/**
* get retCode
*
* @return retCode
*/
public int getRetCode() { public int getRetCode() {
return retCode; return retCode;
} }
public void genDHKeyPairs() { private FLClientStatus genDHKeyPairs() {
byte[] csk = keyAgreement.generatePrivateKey(); byte[] csk = keyAgreement.generatePrivateKey();
byte[] cpk = keyAgreement.generatePublicKey(csk); byte[] cpk = keyAgreement.generatePublicKey(csk);
if (cpk == null || cpk.length == 0) {
LOGGER.severe(Common.addTag("[genDHKeyPairs] the return byte[] <cpk> is null, please check!"));
return FLClientStatus.FAILED;
}
byte[] ssk = keyAgreement.generatePrivateKey(); byte[] ssk = keyAgreement.generatePrivateKey();
byte[] spk = keyAgreement.generatePublicKey(ssk); byte[] spk = keyAgreement.generatePublicKey(ssk);
if (spk == null || spk.length == 0) {
LOGGER.severe(Common.addTag("[genDHKeyPairs] the return byte[] <spk> is null, please check!"));
return FLClientStatus.FAILED;
}
this.cKey.clear();
this.sKey.clear();
this.cKey.add(cpk); this.cKey.add(cpk);
this.cKey.add(csk); this.cKey.add(csk);
this.sKey.add(spk); this.sKey.add(spk);
this.sKey.add(ssk); this.sKey.add(ssk);
return FLClientStatus.SUCCESS;
} }
public void genIndividualSecret() { private FLClientStatus genIndividualSecret() {
byte[] key = new byte[SEED_SIZE]; byte[] key = new byte[SEED_SIZE];
random.getRandomBytes(key); int tag = masking.getRandomBytes(key);
setBU(key); if (tag == -1) {
LOGGER.severe(Common.addTag("[genIndividualSecret] the return value is -1, please check!"));
return FLClientStatus.FAILED;
}
this.bu = key;
return FLClientStatus.SUCCESS;
} }
public List<ShareSecret> genSecretShares(byte[] secret) throws UnsupportedEncodingException { private List<ShareSecret> genSecretShares(byte[] secret) {
List<ShareSecret> shareSecretList = new ArrayList<>(); if (secret == null || secret.length == 0) {
LOGGER.severe(Common.addTag("[genSecretShares] the input argument <secret> is null"));
return new ArrayList<>();
}
int size = u1ClientList.size(); int size = u1ClientList.size();
ShareSecrets shamir = new ShareSecrets(t, size - 1); if (size <= 1) {
ShareSecrets.SecretShare[] shares = shamir.split(secret, prime); LOGGER.severe(Common.addTag("[genSecretShares] the size of u1ClientList is not valid: <= 1, it should be " +
int j = 0; "> 1"));
for (int i = 0; i < size; i++) { return new ArrayList<>();
String vFlID = u1ClientList.get(i); }
ShareSecrets shamir = new ShareSecrets(minShareNum, size - 1);
ShareSecrets.SecretShares[] shares = shamir.split(secret, prime);
if (shares == null || shares.length == 0) {
LOGGER.severe(Common.addTag("[genSecretShares] the return ShareSecrets.SecretShare[] is null, please " +
"check!"));
return new ArrayList<>();
}
int shareIndex = 0;
List<ShareSecret> shareSecretList = new ArrayList<>();
for (String vFlID : u1ClientList) {
if (localFLParameter.getFlID().equals(vFlID)) { if (localFLParameter.getFlID().equals(vFlID)) {
continue; continue;
} else {
ShareSecret shareSecret = new ShareSecret();
NewArray<byte[]> array = new NewArray<>();
int index = shares[j].getNum();
BigInteger intShare = shares[j].getShare();
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
array.setSize(share.length);
array.setArray(share);
shareSecret.setFlID(vFlID);
shareSecret.setShare(array);
shareSecret.setIndex(index);
shareSecretList.add(shareSecret);
j += 1;
} }
if (shareIndex >= shares.length) {
LOGGER.severe(Common.addTag("[genSecretShares] the shareIndex is out of range in array <shares>, " +
"please check!"));
return new ArrayList<>();
}
int index = shares[shareIndex].getNumber();
BigInteger intShare = shares[shareIndex].getShares();
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
NewArray<byte[]> array = new NewArray<>();
array.setSize(share.length);
array.setArray(share);
ShareSecret shareSecret = new ShareSecret();
shareSecret.setFlID(vFlID);
shareSecret.setShare(array);
shareSecret.setIndex(index);
shareSecretList.add(shareSecret);
shareIndex += 1;
} }
return shareSecretList; return shareSecretList;
} }
public void genEncryptExchangedKeys() throws InvalidKeySpecException, NoSuchAlgorithmException { private FLClientStatus genEncryptExchangedKeys() {
cUVKeys.clear(); cUVKeys.clear();
for (String key : clientPublicKeyList.keySet()) { for (String key : clientPublicKeyList.keySet()) {
ClientPublicKey curPublicKey = clientPublicKeyList.get(key); ClientPublicKey curPublicKey = clientPublicKeyList.get(key);
String vFlID = curPublicKey.getFlID(); String vFlID = curPublicKey.getFlID();
if (localFLParameter.getFlID().equals(vFlID)) { if (localFLParameter.getFlID().equals(vFlID)) {
continue; continue;
} else {
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
byte[] salt = new byte[0];
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
cUVKeys.put(vFlID, secret);
} }
if (cKey.size() < 2) {
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the size of cKey is not valid: < 2, it should " +
"be >= 2, please check!"));
return FLClientStatus.FAILED;
}
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
if (secret1 == null || secret1.length == 0) {
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the returned secret1 is null, please check!"));
return FLClientStatus.FAILED;
}
byte[] salt = new byte[0];
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
if (secret == null || secret.length == 0) {
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the returned secret is null, please check!"));
return FLClientStatus.FAILED;
}
cUVKeys.put(vFlID, secret);
} }
return FLClientStatus.SUCCESS;
} }
public void encryptShares() throws Exception { private FLClientStatus encryptShares() {
LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for RequestShareSecrets **************")); LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for " +
List<EncryptShare> encryptShareList = new ArrayList<>(); "RequestShareSecrets **************"));
// connect sSkUv, bUV, sIndex, indexB and then Encrypt them // connect sSkUv, bUV, sIndex, indexB and then Encrypt them
if (sKey.size() < 2) {
LOGGER.severe(Common.addTag("[encryptShares] the size of sKey is not valid: < 2, it should be >= 2, " +
"please check!"));
return FLClientStatus.FAILED;
}
List<ShareSecret> sSkUv = genSecretShares(sKey.get(1)); List<ShareSecret> sSkUv = genSecretShares(sKey.get(1));
if (sSkUv.isEmpty()) {
LOGGER.severe(Common.addTag("[encryptShares] the returned List<ShareSecret> sSkUv is empty, please " +
"check!"));
return FLClientStatus.FAILED;
}
List<ShareSecret> bUV = genSecretShares(bu); List<ShareSecret> bUV = genSecretShares(bu);
if (sSkUv.isEmpty()) {
LOGGER.severe(Common.addTag("[encryptShares] the returned List<ShareSecret> bUV is empty, please check!"));
return FLClientStatus.FAILED;
}
if (sSkUv.size() != bUV.size()) {
LOGGER.severe(Common.addTag("[encryptShares] the sSkUv.size() should be equal to bUV.size(), please " +
"check!"));
return FLClientStatus.FAILED;
}
List<EncryptShare> encryptShareList = new ArrayList<>();
for (int i = 0; i < bUV.size(); i++) { for (int i = 0; i < bUV.size(); i++) {
EncryptShare encryptShare = new EncryptShare();
NewArray<byte[]> array = new NewArray<>();
String vFlID = bUV.get(i).getFlID();
byte[] sShare = sSkUv.get(i).getShare().getArray(); byte[] sShare = sSkUv.get(i).getShare().getArray();
byte[] bShare = bUV.get(i).getShare().getArray(); byte[] bShare = bUV.get(i).getShare().getArray();
byte[] sIndex = BaseUtil.integer2byteArray(sSkUv.get(i).getIndex()); byte[] sIndex = BaseUtil.integer2byteArray(sSkUv.get(i).getIndex());
@ -200,53 +295,101 @@ public class CipherClient {
System.arraycopy(sShare, 0, allSecret, 4 + sIndex.length + bIndex.length, sShare.length); System.arraycopy(sShare, 0, allSecret, 4 + sIndex.length + bIndex.length, sShare.length);
System.arraycopy(bShare, 0, allSecret, 4 + sIndex.length + bIndex.length + sShare.length, bShare.length); System.arraycopy(bShare, 0, allSecret, 4 + sIndex.length + bIndex.length + sShare.length, bShare.length);
// encrypt: // encrypt:
byte[] iVecIn = new byte[IVEC_LEN]; String vFlID = bUV.get(i).getFlID();
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), iVecIn, "CBC"); if (!cUVKeys.containsKey(vFlID)) {
LOGGER.severe(Common.addTag("[encryptShares] the key " + vFlID + " is not in map cUVKeys, please " +
"check!"));
return FLClientStatus.FAILED;
}
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), "CBC");
byte[] encryptData = aesEncrypt.encrypt(cUVKeys.get(vFlID), allSecret); byte[] encryptData = aesEncrypt.encrypt(cUVKeys.get(vFlID), allSecret);
if (encryptData == null || encryptData.length == 0) {
LOGGER.severe(Common.addTag("[encryptShares] the return byte[] is null, please check!"));
return FLClientStatus.FAILED;
}
NewArray<byte[]> array = new NewArray<>();
array.setSize(encryptData.length); array.setSize(encryptData.length);
array.setArray(encryptData); array.setArray(encryptData);
EncryptShare encryptShare = new EncryptShare();
encryptShare.setFlID(vFlID); encryptShare.setFlID(vFlID);
encryptShare.setShare(array); encryptShare.setShare(array);
encryptShareList.add(encryptShare); encryptShareList.add(encryptShare);
} }
setClientShareList(encryptShareList); setClientShareList(encryptShareList);
return FLClientStatus.SUCCESS;
} }
public float[] doubleMaskingWeight() throws Exception { /**
int size = u2UClientList.size(); * get masked weight of secure aggregation
*
* @return masked weight
*/
public float[] doubleMaskingWeight() {
List<Float> noiseBu = new ArrayList<>(); List<Float> noiseBu = new ArrayList<>();
random.randomAESCTR(noiseBu, featureSize, bu); int tag = masking.getMasking(noiseBu, featureSize, bu, individualIv);
if (tag == -1) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the return value is -1, please check!"));
return new float[0];
}
float[] mask = new float[featureSize]; float[] mask = new float[featureSize];
for (int i = 0; i < size; i++) { for (String vFlID : u2UClientList) {
String vFlID = u2UClientList.get(i); if (!clientPublicKeyList.containsKey(vFlID)) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the key " + vFlID + " is not in map " +
"clientPublicKeyList, please check!"));
return new float[0];
}
ClientPublicKey curPublicKey = clientPublicKeyList.get(vFlID); ClientPublicKey curPublicKey = clientPublicKeyList.get(vFlID);
if (localFLParameter.getFlID().equals(vFlID)) { if (localFLParameter.getFlID().equals(vFlID)) {
continue; continue;
}
byte[] salt;
byte[] iVec;
if (vFlID.compareTo(localFLParameter.getFlID()) < 0) {
salt = curPublicKey.getPwSalt().getArray();
iVec = curPublicKey.getPwIv().getArray();
} else { } else {
byte[] salt = new byte[0]; salt = this.pwSalt;
byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray()); iVec = this.pwIVec;
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt); }
sUVKeys.put(vFlID, secret); if (sKey.size() < 2) {
List<Float> noiseSuv = new ArrayList<>(); LOGGER.severe(Common.addTag("[doubleMaskingWeight] the size of sKey is not valid: < 2, it should be " +
random.randomAESCTR(noiseSuv, featureSize, secret); ">= 2, please check!"));
int sign; return new float[0];
if (localFLParameter.getFlID().compareTo(vFlID) > 0) { }
sign = 1; byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray());
} else { if (secret1 == null || secret1.length == 0) {
sign = -1; LOGGER.severe(Common.addTag("[doubleMaskingWeight] the returned secret1 is null, please check!"));
} return new float[0];
for (int j = 0; j < noiseSuv.size(); j++) { }
mask[j] = mask[j] + sign * noiseSuv.get(j); byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
} if (secret == null || secret.length == 0) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the returned secret is null, please check!"));
return new float[0];
}
sUVKeys.put(vFlID, secret);
List<Float> noiseSuv = new ArrayList<>();
tag = masking.getMasking(noiseSuv, featureSize, secret, iVec);
if (tag == -1) {
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the return value is -1, please check!"));
return new float[0];
}
int sign;
if (localFLParameter.getFlID().compareTo(vFlID) > 0) {
sign = 1;
} else {
sign = -1;
}
for (int maskIndex = 0; maskIndex < noiseSuv.size(); maskIndex++) {
mask[maskIndex] = mask[maskIndex] + sign * noiseSuv.get(maskIndex);
} }
} }
for (int j = 0; j < noiseBu.size(); j++) { for (int maskIndex = 0; maskIndex < noiseBu.size(); maskIndex++) {
mask[j] = mask[j] + noiseBu.get(j); mask[maskIndex] = mask[maskIndex] + noiseBu.get(maskIndex);
} }
return mask; return mask;
} }
public NewArray<byte[]> byteToArray(ByteBuffer buf, int size) { private NewArray<byte[]> byteToArray(ByteBuffer buf, int size) {
NewArray<byte[]> newArray = new NewArray<>(); NewArray<byte[]> newArray = new NewArray<>();
newArray.setSize(size); newArray.setSize(size);
byte[] array = new byte[size]; byte[] array = new byte[size];
@ -258,40 +401,80 @@ public class CipherClient {
return newArray; return newArray;
} }
public FLClientStatus requestExchangeKeys() { private FLClientStatus requestExchangeKeys() {
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "==============")); LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() +
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); "=============="));
genDHKeyPairs(); FLClientStatus status = genDHKeyPairs();
if (status == FLClientStatus.FAILED) {
LOGGER.severe(Common.addTag("[requestExchangeKeys] the return status is FAILED, please check!"));
return FLClientStatus.FAILED;
}
if (cKey.size() <= 0 || sKey.size() <= 0) {
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of cKey or sKey is not valid: <=0."));
return FLClientStatus.FAILED;
}
if (cKey.size() < 2) {
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of cKey is not valid: < 2, it should be >= 2," +
" please check!"));
return FLClientStatus.FAILED;
}
if (sKey.size() < 2) {
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of sKey is not valid: < 2, it should be >= 2," +
" please check!"));
return FLClientStatus.FAILED;
}
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
byte[] cPK = cKey.get(0); byte[] cPK = cKey.get(0);
byte[] sPK = sKey.get(0); byte[] sPK = sKey.get(0);
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID());
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK); int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK); int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
String dateTime = LocalDateTime.now().toString();
byte[] indIv = new byte[I_VEC_LEN];
byte[] pwIv = new byte[I_VEC_LEN];
byte[] thisPwSalt = new byte[SALT_SIZE];
SecureRandom secureRandom = Common.getSecureRandom();
secureRandom.nextBytes(indIv);
secureRandom.nextBytes(pwIv);
secureRandom.nextBytes(thisPwSalt);
this.individualIv = indIv;
this.pwIVec = pwIv;
this.pwSalt = thisPwSalt;
int indIvFbs = RequestExchangeKeys.createIndIvVector(fbBuilder, indIv);
int pwIvFbs = RequestExchangeKeys.createPwIvVector(fbBuilder, pwIv);
int pwSaltFbs = RequestExchangeKeys.createPwSaltVector(fbBuilder, thisPwSalt);
int id = fbBuilder.createString(localFLParameter.getFlID());
Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime); int time = fbBuilder.createString(dateTime);
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time);
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time
, indIvFbs, pwIvFbs, pwSaltFbs);
fbBuilder.finish(exchangeKeysRoot); fbBuilder.finish(exchangeKeysRoot);
byte[] msg = fbBuilder.sizedByteArray(); byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try { try {
byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg); byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[requestExchangeKeys] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[requestExchangeKeys] the server is not ready now, need wait some time and" +
" " +
"request again"));
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
nextRequestTime = ""; nextRequestTime = "";
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
} }
ByteBuffer buffer = ByteBuffer.wrap(responseData); ByteBuffer buffer = ByteBuffer.wrap(responseData);
ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer); ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer);
FLClientStatus status = judgeRequestExchangeKeys(responseExchangeKeys); return judgeRequestExchangeKeys(responseExchangeKeys);
return status; } catch (IOException ex) {
} catch (Exception e) { LOGGER.severe(Common.addTag("[requestExchangeKeys] catch IOException: " + ex.getMessage()));
e.printStackTrace();
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) { private FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
retCode = bufData.retcode(); retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************")); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -303,7 +486,8 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success")); LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success"));
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime): case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request startFLJob again")); LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime()); setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
case (ResponseCode.RequestError): case (ResponseCode.RequestError):
@ -311,39 +495,43 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys")); LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
default: default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys is invalid: " + retCode)); LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys " +
"is invalid: " + retCode));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
public FLClientStatus getExchangeKeys() { private FLClientStatus getExchangeKeys() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID()); int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString(); Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime); int time = fbBuilder.createString(dateTime);
int getExchangeKeysRoot = GetExchangeKeys.createGetExchangeKeys(fbBuilder, id, iteration, time); int getExchangeKeysRoot = GetExchangeKeys.createGetExchangeKeys(fbBuilder, id, iteration, time);
fbBuilder.finish(getExchangeKeysRoot); fbBuilder.finish(getExchangeKeysRoot);
byte[] msg = fbBuilder.sizedByteArray(); byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try { try {
byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg); byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getExchangeKeys] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[getExchangeKeys] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
nextRequestTime = ""; nextRequestTime = "";
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
} }
ByteBuffer buffer = ByteBuffer.wrap(responseData); ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer); ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer);
FLClientStatus status = judgeGetExchangeKeys(returnExchangeKeys); return judgeGetExchangeKeys(returnExchangeKeys);
return status; } catch (IOException ex) {
} catch (Exception e) { LOGGER.severe(Common.addTag("[getExchangeKeys] catch IOException: " + ex.getMessage()));
e.printStackTrace();
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) { private FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
retCode = bufData.retcode(); retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************")); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -363,17 +551,25 @@ public class CipherClient {
int sizeCpk = bufData.remotePublickeys(i).cPkLength(); int sizeCpk = bufData.remotePublickeys(i).cPkLength();
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer(); ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
int sizeSpk = bufData.remotePublickeys(i).sPkLength(); int sizeSpk = bufData.remotePublickeys(i).sPkLength();
ByteBuffer bufPwIv = bufData.remotePublickeys(i).pwIvAsByteBuffer();
int sizePwIv = bufData.remotePublickeys(i).pwIvLength();
ByteBuffer bufPwSalt = bufData.remotePublickeys(i).pwSaltAsByteBuffer();
int sizePwSalt = bufData.remotePublickeys(i).pwSaltLength();
publicKey.setCPK(byteToArray(bufCpk, sizeCpk)); publicKey.setCPK(byteToArray(bufCpk, sizeCpk));
publicKey.setSPK(byteToArray(bufSpk, sizeSpk)); publicKey.setSPK(byteToArray(bufSpk, sizeSpk));
publicKey.setPwIv(byteToArray(bufPwIv, sizePwIv));
publicKey.setPwSalt(byteToArray(bufPwSalt, sizePwSalt));
clientPublicKeyList.put(srcFlId, publicKey); clientPublicKeyList.put(srcFlId, publicKey);
u1ClientList.add(srcFlId); u1ClientList.add(srcFlId);
} }
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
case (ResponseCode.SucNotReady): case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetExchangeKeys again!")); LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
"GetExchangeKeys again!"));
return FLClientStatus.WAIT; return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime): case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request startFLJob again")); LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime()); setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
case (ResponseCode.RequestError): case (ResponseCode.RequestError):
@ -381,32 +577,48 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys")); LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
default: default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is invalid: " + retCode)); LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is" +
" invalid: " + retCode));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
public FLClientStatus requestShareSecrets() throws Exception { private FLClientStatus requestShareSecrets() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); FLClientStatus status = genIndividualSecret();
genIndividualSecret(); if (status == FLClientStatus.FAILED) {
genEncryptExchangedKeys(); LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from genIndividualSecret" +
encryptShares(); "(), please check!"));
return FLClientStatus.FAILED;
}
status = genEncryptExchangedKeys();
if (status == FLClientStatus.FAILED) {
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from " +
"genEncryptExchangedKeys(), please check!"));
return FLClientStatus.FAILED;
}
status = encryptShares();
if (status == FLClientStatus.FAILED) {
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from encryptShares(), " +
"please check!"));
return FLClientStatus.FAILED;
}
FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID()); int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString(); Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime); int time = fbBuilder.createString(dateTime);
int clientShareSize = clientShareList.size(); int clientShareSize = clientShareList.size();
if (clientShareSize <= 0) { if (clientShareSize <= 0) {
LOGGER.warning(Common.addTag("[PairWiseMask] encrypt shares is not ready now!")); LOGGER.warning(Common.addTag("[PairWiseMask] encrypt shares is not ready now!"));
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
FLClientStatus status = requestShareSecrets(); return requestShareSecrets();
return status;
} else { } else {
int[] add = new int[clientShareSize]; int[] add = new int[clientShareSize];
for (int i = 0; i < clientShareSize; i++) { for (int i = 0; i < clientShareSize; i++) {
int flID = fbBuilder.createString(clientShareList.get(i).getFlID()); int flID = fbBuilder.createString(clientShareList.get(i).getFlID());
int shareSecretFbs = ClientShare.createShareVector(fbBuilder, clientShareList.get(i).getShare().getArray()); int shareSecretFbs = ClientShare.createShareVector(fbBuilder,
clientShareList.get(i).getShare().getArray());
ClientShare.startClientShare(fbBuilder); ClientShare.startClientShare(fbBuilder);
ClientShare.addFlId(fbBuilder, flID); ClientShare.addFlId(fbBuilder, flID);
ClientShare.addShare(fbBuilder, shareSecretFbs); ClientShare.addShare(fbBuilder, shareSecretFbs);
@ -414,29 +626,33 @@ public class CipherClient {
add[i] = clientShareRoot; add[i] = clientShareRoot;
} }
int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add); int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add);
int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id, encryptedSharesFbs, iteration, time); int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id,
encryptedSharesFbs, iteration, time);
fbBuilder.finish(requestShareSecretsRoot); fbBuilder.finish(requestShareSecretsRoot);
byte[] msg = fbBuilder.sizedByteArray(); byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try { try {
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg); byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[requestShareSecrets] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[requestShareSecrets] the server is not ready now, need wait some time" +
" " +
"and request again"));
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
nextRequestTime = ""; nextRequestTime = "";
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
} }
ByteBuffer buffer = ByteBuffer.wrap(responseData); ByteBuffer buffer = ByteBuffer.wrap(responseData);
ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer); ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer);
FLClientStatus status = judgeRequestShareSecrets(responseShareSecrets); return judgeRequestShareSecrets(responseShareSecrets);
return status; } catch (IOException ex) {
} catch (Exception e) { LOGGER.severe(Common.addTag("[requestShareSecrets] catch IOException: " + ex.getMessage()));
e.printStackTrace();
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
} }
public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) { private FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
retCode = bufData.retcode(); retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************")); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -448,7 +664,8 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success")); LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success"));
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime): case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request startFLJob again")); LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime()); setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
case (ResponseCode.RequestError): case (ResponseCode.RequestError):
@ -456,39 +673,43 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets")); LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
default: default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets is invalid: " + retCode)); LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets " +
"is invalid: " + retCode));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
public FLClientStatus getShareSecrets() { private FLClientStatus getShareSecrets() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
FlatBufferBuilder fbBuilder = new FlatBufferBuilder(); FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
int id = fbBuilder.createString(localFLParameter.getFlID()); int id = fbBuilder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString(); Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = fbBuilder.createString(dateTime); int time = fbBuilder.createString(dateTime);
int getShareSecrets = GetShareSecrets.createGetShareSecrets(fbBuilder, id, iteration, time); int getShareSecrets = GetShareSecrets.createGetShareSecrets(fbBuilder, id, iteration, time);
fbBuilder.finish(getShareSecrets); fbBuilder.finish(getShareSecrets);
byte[] msg = fbBuilder.sizedByteArray(); byte[] msg = fbBuilder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
try { try {
byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg); byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getShareSecrets] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[getShareSecrets] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
nextRequestTime = ""; nextRequestTime = "";
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
} }
ByteBuffer buffer = ByteBuffer.wrap(responseData); ByteBuffer buffer = ByteBuffer.wrap(responseData);
ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer); ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer);
FLClientStatus status = judgeGetShareSecrets(returnShareSecrets); return judgeGetShareSecrets(returnShareSecrets);
return status; } catch (IOException ex) {
} catch (Exception e) { LOGGER.severe(Common.addTag("[getShareSecrets] catch IOException: " + ex.getMessage()));
e.printStackTrace();
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) { private FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
retCode = bufData.retcode(); retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************")); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -503,20 +724,26 @@ public class CipherClient {
int length = bufData.encryptedSharesLength(); int length = bufData.encryptedSharesLength();
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
EncryptShare shareSecret = new EncryptShare(); EncryptShare shareSecret = new EncryptShare();
shareSecret.setFlID(bufData.encryptedShares(i).flId()); ClientShare clientShare = bufData.encryptedShares(i);
ByteBuffer bufShare = bufData.encryptedShares(i).shareAsByteBuffer(); if (clientShare == null) {
int sizeShare = bufData.encryptedShares(i).shareLength(); LOGGER.severe(Common.addTag("[PairWiseMask] the clientShare returned from server is null"));
return FLClientStatus.FAILED;
}
shareSecret.setFlID(clientShare.flId());
ByteBuffer bufShare = clientShare.shareAsByteBuffer();
int sizeShare = clientShare.shareLength();
shareSecret.setShare(byteToArray(bufShare, sizeShare)); shareSecret.setShare(byteToArray(bufShare, sizeShare));
returnShareList.add(shareSecret); returnShareList.add(shareSecret);
u2UClientList.add(bufData.encryptedShares(i).flId()); u2UClientList.add(clientShare.flId());
} }
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
case (ResponseCode.SucNotReady): case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetShareSecrets again!")); LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
"GetShareSecrets again!"));
return FLClientStatus.WAIT; return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime): case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request startFLJob again")); LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime()); setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
case (ResponseCode.RequestError): case (ResponseCode.RequestError):
@ -524,15 +751,22 @@ public class CipherClient {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets")); LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
default: default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is invalid: " + retCode)); LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is" +
" invalid: " + retCode));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
/**
* exchangeKeys round of secure aggregation
*
* @return round execution result
*/
public FLClientStatus exchangeKeys() { public FLClientStatus exchangeKeys() {
LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys ======================")); LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys " +
FLClientStatus curStatus; "======================"));
// RequestExchangeKeys // RequestExchangeKeys
FLClientStatus curStatus;
curStatus = requestExchangeKeys(); curStatus = requestExchangeKeys();
while (curStatus == FLClientStatus.WAIT) { while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
@ -551,8 +785,14 @@ public class CipherClient {
return curStatus; return curStatus;
} }
public FLClientStatus shareSecrets() throws Exception { /**
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets ======================"))); * shareSecrets round of secure aggregation
*
* @return round execution result
*/
public FLClientStatus shareSecrets() {
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets " +
"======================")));
FLClientStatus curStatus; FLClientStatus curStatus;
// RequestShareSecrets // RequestShareSecrets
curStatus = requestShareSecrets(); curStatus = requestShareSecrets();
@ -573,14 +813,22 @@ public class CipherClient {
return curStatus; return curStatus;
} }
/**
* reconstructSecrets round of secure aggregation
*
* @return round execution result
*/
public FLClientStatus reconstructSecrets() { public FLClientStatus reconstructSecrets() {
LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret ========================")); LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret " +
"========================"));
FLClientStatus curStatus; FLClientStatus curStatus;
// GetClientList // GetClientList
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys); curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList,
cUVKeys);
while (curStatus == FLClientStatus.WAIT) { while (curStatus == FLClientStatus.WAIT) {
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys); curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList
, cUVKeys);
} }
if (curStatus == FLClientStatus.RESTART) { if (curStatus == FLClientStatus.RESTART) {
nextRequestTime = clientListReq.getNextRequestTime(); nextRequestTime = clientListReq.getNextRequestTime();

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,10 +13,17 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import org.bouncycastle.crypto.BlockCipher;
import org.bouncycastle.crypto.engines.AESEngine;
import org.bouncycastle.crypto.prng.SP800SecureRandomBuilder;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Date; import java.util.Date;
@ -26,28 +33,94 @@ import java.util.logging.Logger;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
/**
* Define basic global methods used in federated learning task.
*
* @since 2021-06-30
*/
public class Common { public class Common {
/**
* Global logger title.
*/
public static final String LOG_TITLE = "<FLClient> "; public static final String LOG_TITLE = "<FLClient> ";
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "albert"));
public static String generateUrl(boolean useHttps, boolean useElb, String ip, int port, int serverNum) { /**
if (useHttps) { * The list of trust flName.
ip = "https://" + ip + ":"; */
} else { public static final List<String> FL_NAME_TRUST_LIST = new ArrayList<>(Arrays.asList("lenet", "albert"));
ip = "http://" + ip + ":";
/**
* The list of trust ssl protocol.
*/
public static final List<String> SSL_PROTOCOL_TRUST_LIST = new ArrayList<>(Arrays.asList("TLSv1.3", "TLSv1.2"));
/**
* The tag when server is in safe mode.
*/
public static final String SAFE_MOD = "The cluster is in safemode.";
/**
* The tag when server is not ready.
*/
public static final String JOB_NOT_AVAILABLE = "The server's training job is disabled or finished.";
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
private static SecureRandom secureRandom;
/**
* Generate the URL for device-sever interaction
*
* @param ifUseElb whether a client randomly sends a request to a server address within a specified range.
* @param serverNum number of servers that can send requests.
* @param domainName the URL for device-sever interaction set by user.
* @return the URL for device-sever interaction.
*/
public static String generateUrl(boolean ifUseElb, int serverNum, String domainName) {
if (serverNum <= 0) {
LOGGER.severe(Common.addTag("[generateUrl] the input argument <serverNum> is not valid: <= 0, it should " +
"be > 0, please check!"));
throw new IllegalArgumentException();
} }
String url; String url;
if (useElb) { if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
LOGGER.severe(Common.addTag("[generateUrl] the input argument <domainName> is null or not valid, it " +
"should be like as https://...... or http://...... , please check!"));
throw new IllegalArgumentException();
}
if (ifUseElb) {
if (domainName.split("//")[1].split(":").length < 2) {
LOGGER.severe(Common.addTag("[generateUrl] the format of <domainName> is not valid, it should be like" +
" as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, please check!"));
throw new IllegalArgumentException();
}
String ip = domainName.split("//")[1].split(":")[0];
int port = Integer.parseInt(domainName.split("//")[1].split(":")[1]);
if (!Common.checkIP(ip)) {
LOGGER.severe(Common.addTag("[generateUrl] the <ip> split from domainName is not valid, domainName " +
"should be like as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, " +
"please check!"));
throw new IllegalArgumentException();
}
if (!Common.checkPort(port)) {
LOGGER.severe(Common.addTag("[generateUrl] the <port> split from domainName is not valid, domainName " +
"should be like as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, " +
"please check!"));
throw new IllegalArgumentException();
}
String tag = domainName.split("//")[0] + "//";
Random rand = new Random(); Random rand = new Random();
int randomNum = rand.nextInt(100000) % serverNum + port; int randomNum = rand.nextInt(100000) % serverNum + port;
url = ip + String.valueOf(randomNum); url = tag + ip + ":" + String.valueOf(randomNum);
} else { } else {
url = ip + String.valueOf(port); url = domainName;
} }
return url; return url;
} }
/**
* Store weight name of classifier to a list.
*
* @param classifierWeightName the list to store weight name of classifier.
*/
public static void setClassifierWeightName(List<String> classifierWeightName) { public static void setClassifierWeightName(List<String> classifierWeightName) {
classifierWeightName.add("albert.pooler.weight"); classifierWeightName.add("albert.pooler.weight");
classifierWeightName.add("albert.pooler.bias"); classifierWeightName.add("albert.pooler.bias");
@ -56,6 +129,11 @@ public class Common {
LOGGER.info(addTag("classifierWeightName size: " + classifierWeightName.size())); LOGGER.info(addTag("classifierWeightName size: " + classifierWeightName.size()));
} }
/**
* Store weight name of albert network to a list.
*
* @param albertWeightName the list to store weight name of albert network.
*/
public static void setAlbertWeightName(List<String> albertWeightName) { public static void setAlbertWeightName(List<String> albertWeightName) {
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.weight"); albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.weight");
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.bias"); albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.bias");
@ -78,32 +156,67 @@ public class Common {
LOGGER.info(addTag("albertWeightName size: " + albertWeightName.size())); LOGGER.info(addTag("albertWeightName size: " + albertWeightName.size()));
} }
/**
* Check whether the flName set by user is in the trust list.
*
* @param flName the model name set by user.
* @return boolean value, true indicates the flName set by user is valid, false indicates the flName set by user
* is not valid.
*/
public static boolean checkFLName(String flName) { public static boolean checkFLName(String flName) {
return (flNameTrustList.contains(flName)); return (FL_NAME_TRUST_LIST.contains(flName));
} }
/**
* Check whether the sslProtocol set by user is in the trust list.
*
* @param sslProtocol the ssl protocol set by user.
* @return boolean value, true indicates the sslProtocol set by user is valid, false indicates the sslProtocol
* set by user is not valid.
*/
public static boolean checkSSLProtocol(String sslProtocol) {
return (SSL_PROTOCOL_TRUST_LIST.contains(sslProtocol));
}
/**
* The program waits for the specified time and then to continue.
*
* @param millis the waiting time (ms).
*/
public static void sleep(long millis) { public static void sleep(long millis) {
try { try {
Thread.sleep(millis); //1000 milliseconds is one second. Thread.sleep(millis); // 1000 milliseconds is one second.
} catch (InterruptedException ex) { } catch (InterruptedException ex) {
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage())); LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
Thread.currentThread().interrupt(); Thread.currentThread().interrupt();
} }
} }
/**
* Get the waiting time for repeated requests.
*
* @param nextRequestTime the timestamp return from server.
* @return the waiting time for repeated requests.
*/
public static long getWaitTime(String nextRequestTime) { public static long getWaitTime(String nextRequestTime) {
Date date = new Date(); Date date = new Date();
long currentTime = date.getTime(); long currentTime = date.getTime();
long waitTime = 0; long waitTime = 0L;
if (!("").equals(nextRequestTime)) { if (!(nextRequestTime == null || nextRequestTime.isEmpty())) {
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime); waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
} }
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime)); LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " +
currentTime));
LOGGER.info(addTag("[getWaitTime] waitTime: " + waitTime)); LOGGER.info(addTag("[getWaitTime] waitTime: " + waitTime));
return waitTime; return waitTime;
} }
/**
* Get start time.
*
* @param tag the tag added to the logger.
* @return start time.
*/
public static long startTime(String tag) { public static long startTime(String tag) {
Date startDate = new Date(); Date startDate = new Date();
long startTime = startDate.getTime(); long startTime = startDate.getTime();
@ -111,6 +224,12 @@ public class Common {
return startTime; return startTime;
} }
/**
* Get end time.
*
* @param start the start time.
* @param tag the tag added to the logger.
*/
public static void endTime(long start, String tag) { public static void endTime(long start, String tag) {
Date endDate = new Date(); Date endDate = new Date();
long endTime = endDate.getTime(); long endTime = endDate.getTime();
@ -118,53 +237,182 @@ public class Common {
LOGGER.info(addTag("[interval time] <" + tag + "> interval time(ms): " + (endTime - start))); LOGGER.info(addTag("[interval time] <" + tag + "> interval time(ms): " + (endTime - start)));
} }
/**
* Add specified tag to the message.
*
* @param message the message need to add tag.
* @return the message after adding tag.
*/
public static String addTag(String message) { public static String addTag(String message) {
return LOG_TITLE + message; return LOG_TITLE + message;
} }
public static boolean isSafeMod(byte[] message, String safeModTag) { /**
return (new String(message)).contains(safeModTag); * Check whether the server is ready based on the message returned by the server.
*
* @param message the message returned by the server..
* @return boolean value, true indicates the server is ready, false indicates the server is not ready.
*/
public static boolean isSeverReady(byte[] message) {
if (message == null) {
LOGGER.severe(Common.addTag("[isSeverReady] the input argument <message> is null, please check!"));
throw new IllegalArgumentException();
}
String messageStr = new String(message);
if (messageStr.contains(SAFE_MOD)) {
LOGGER.info(Common.addTag("[isSeverReady] " + SAFE_MOD + ", need wait some time and request again"));
return false;
} else if (messageStr.contains(JOB_NOT_AVAILABLE)) {
LOGGER.info(Common.addTag("[isSeverReady] " + JOB_NOT_AVAILABLE + ", need wait some time and request " +
"again"));
return false;
} else {
return true;
}
} }
public static String getRealPath (String path) { /**
LOGGER.info(addTag("[original path] " + path)); * Convert a user-set path to a standard path.
*
* @param path the user-set path.
* @return the standard path.
*/
public static String getRealPath(String path) {
if (path == null) {
LOGGER.severe(Common.addTag("[getRealPath] the input argument <path> is null, please check!"));
throw new IllegalArgumentException();
}
LOGGER.info(addTag("[getRealPath] original path: " + path));
String[] paths = path.split(","); String[] paths = path.split(",");
for (int i = 0; i < paths.length; i++) { for (int i = 0; i < paths.length; i++) {
LOGGER.info(addTag("[original path " + i + "] " + paths[i])); if (paths[i] == null) {
LOGGER.severe(Common.addTag("[getRealPath] the paths[i] is null, please check"));
throw new IllegalArgumentException();
}
LOGGER.info(addTag("[getRealPath] original path " + i + ": " + paths[i]));
File file = new File(paths[i]); File file = new File(paths[i]);
try { try {
paths[i] = file.getCanonicalPath(); paths[i] = file.getCanonicalPath();
} catch (IOException e) { } catch (IOException e) {
LOGGER.severe(addTag("[checkPath] catch IOException in file.getCanonicalPath(): " + e.getMessage())); LOGGER.severe(addTag("[getRealPath] catch IOException in file.getCanonicalPath(): " + e.getMessage()));
throw new RuntimeException(); throw new IllegalArgumentException();
} }
} }
path = String.join(",", Arrays.asList(paths)); String realPath = String.join(",", Arrays.asList(paths));
LOGGER.info(addTag("[real path] " + path)); LOGGER.info(addTag("[getRealPath] real path: " + realPath));
return path; return realPath;
} }
/**
* Check whether the path set by user exists.
*
* @param path the path set by user.
* @return boolean value, true indicates the path is exist, false indicates the path is not exist
*/
public static boolean checkPath(String path) { public static boolean checkPath(String path) {
boolean tag = true; if (path == null) {
LOGGER.severe(Common.addTag("[checkPath] the input argument <path> is null, please check!"));
return false;
}
String[] paths = path.split(","); String[] paths = path.split(",");
for (int i = 0; i < paths.length; i++) { for (int i = 0; i < paths.length; i++) {
if (paths[i] == null) {
LOGGER.severe(Common.addTag("[checkPath] the paths[i] is null, please check"));
return false;
}
LOGGER.info(addTag("[check path " + i + "] " + paths[i])); LOGGER.info(addTag("[check path " + i + "] " + paths[i]));
File file = new File(paths[i]); File file = new File(paths[i]);
if (!file.exists()) { if (!file.exists()) {
tag = false; LOGGER.severe(Common.addTag("[checkPath] the path is not exist, please check"));
return false;
} }
} }
return tag; return true;
} }
/**
* Check whether the ip set by user is valid.
*
* @param ip the ip set by user.
* @return boolean value, true indicates the ip is valid, false indicates the ip is not valid.
*/
public static boolean checkIP(String ip) { public static boolean checkIP(String ip) {
String regex = "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])"; if (ip == null) {
LOGGER.severe(Common.addTag("[checkIP] the input argument <ip> is null, please check!"));
throw new IllegalArgumentException();
}
String regex = "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])[.]" +
"(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.]" +
"(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.]" +
"(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])";
Pattern pattern = Pattern.compile(regex); Pattern pattern = Pattern.compile(regex);
Matcher matcher = pattern.matcher(ip); Matcher matcher = pattern.matcher(ip);
return matcher.matches(); return matcher.matches();
} }
/**
* Check whether the port set by user is valid.
*
* @param port the port set by user.
* @return boolean value, true indicates the port is valid, false indicates the port is not valid.
*/
public static boolean checkPort(int port) { public static boolean checkPort(int port) {
return port > 0 && port <= 65535; return port > 0 && port <= 65535;
} }
/**
* Obtain secure random.
*
* @return the secure random.
*/
public static SecureRandom getSecureRandom() {
if (secureRandom == null) {
LOGGER.severe(Common.addTag("[setSecureRandom] the parameter secureRandom is null, please set it before " +
"use"));
throw new IllegalArgumentException();
}
return secureRandom;
}
/**
* Set the secure random to parameter secureRandom of the class Common.
*
* @param secureRandom the secure random.
*/
public static void setSecureRandom(SecureRandom secureRandom) {
if (secureRandom == null) {
LOGGER.severe(Common.addTag("[setSecureRandom] the input parameter secureRandom is null, please check"));
throw new IllegalArgumentException();
}
Common.secureRandom = secureRandom;
}
/**
* Obtain fast secure random.
*
* @return the fast secure random.
*/
public static SecureRandom getFastSecureRandom() {
try {
LOGGER.info(Common.addTag("[getFastSecureRandom] start create fastSecureRandom"));
long start = System.currentTimeMillis();
SecureRandom blockingRandom = SecureRandom.getInstanceStrong();
boolean ifPredictionResistant = true;
BlockCipher cipher = new AESEngine();
int cipherLen = 256;
int entropyBitsRequired = 384;
byte[] nonce = null;
boolean ifForceReseed = false;
SecureRandom fastRandom = new SP800SecureRandomBuilder(blockingRandom, ifPredictionResistant)
.setEntropyBitsRequired(entropyBitsRequired)
.buildCTR(cipher, cipherLen, nonce, ifForceReseed);
fastRandom.nextInt();
LOGGER.info(Common.addTag("[getFastSecureRandom] finish create fastSecureRandom"));
LOGGER.info(Common.addTag("[getFastSecureRandom] cost time: " + (System.currentTimeMillis() - start)));
return fastRandom;
} catch (NoSuchAlgorithmException e) {
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException: " + e.getMessage()));
throw new IllegalArgumentException();
}
}
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,13 +13,17 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
/**
* The early stop mod.
*
* @since 2021-06-30
*/
public enum EarlyStopMod { public enum EarlyStopMod {
LOSS_DIFF, LOSS_DIFF,
LOSS_ABS, LOSS_ABS,
WEIGHT_DIFF, WEIGHT_DIFF,
NOT_EARLY_STOP NOT_EARLY_STOP
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,12 +13,16 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
/**
* Security encryption level.
*
* @since 2021-06-30
*/
public enum EncryptLevel { public enum EncryptLevel {
PW_ENCRYPT, PW_ENCRYPT,
DP_ENCRYPT, DP_ENCRYPT,
NOT_ENCRYPT NOT_ENCRYPT
} }

View File

@ -1,21 +1,26 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
/**
* The status code of federated learning.
*
* @since 2021-06-30
*/
public enum FLClientStatus { public enum FLClientStatus {
SUCCESS, SUCCESS,
FAILED, FAILED,

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,6 +16,8 @@
package com.mindspore.flclient; package com.mindspore.flclient;
import static com.mindspore.flclient.FLParameter.TIME_OUT;
import okhttp3.Call; import okhttp3.Call;
import okhttp3.Callback; import okhttp3.Callback;
import okhttp3.MediaType; import okhttp3.MediaType;
@ -24,12 +26,6 @@ import okhttp3.Request;
import okhttp3.RequestBody; import okhttp3.RequestBody;
import okhttp3.Response; import okhttp3.Response;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException; import java.io.IOException;
import java.security.KeyManagementException; import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
@ -39,34 +35,40 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.TIME_OUT; import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
/**
* Define the communication interface.
*
* @since 2021-06-30
*/
public class FLCommunication implements IFLCommunication { public class FLCommunication implements IFLCommunication {
private static int timeOut; private static int timeOut;
private static boolean ssl = false; private static boolean ifCertificateVerify = false;
private static String env;
private static SSLSocketFactory sslSocketFactory;
private static X509TrustManager x509TrustManager;
private FLParameter flParameter = FLParameter.getInstance();
private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8"); private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8");
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString()); private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
private OkHttpClient client;
private static volatile FLCommunication communication; private static volatile FLCommunication communication;
private FLParameter flParameter = FLParameter.getInstance();
private OkHttpClient client;
private FLCommunication() { private FLCommunication() {
if (flParameter.getTimeOut() != 0) { if (flParameter.getTimeOut() != 0) {
timeOut = flParameter.getTimeOut(); timeOut = flParameter.getTimeOut();
} else { } else {
timeOut = TIME_OUT; timeOut = TIME_OUT;
} }
ssl = flParameter.isUseSSL(); ifCertificateVerify = flParameter.isUseSSL();
client = getOkHttpClient(); client = getOkHttpClient();
} }
private static OkHttpClient getOkHttpClient() { private static OkHttpClient getOkHttpClient() {
X509TrustManager trustManager = new X509TrustManager() { X509TrustManager trustManager = new X509TrustManager() {
@Override @Override
public X509Certificate[] getAcceptedIssuers() { public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[]{}; return new X509Certificate[]{};
@ -89,14 +91,15 @@ public class FLCommunication implements IFLCommunication {
builder.connectTimeout(timeOut, TimeUnit.SECONDS); builder.connectTimeout(timeOut, TimeUnit.SECONDS);
builder.writeTimeout(timeOut, TimeUnit.SECONDS); builder.writeTimeout(timeOut, TimeUnit.SECONDS);
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS); builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
if (ssl) { if (ifCertificateVerify) {
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(), SSLSocketFactoryTools.getInstance().getmTrustManager()); builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(),
SSLSocketFactoryTools.getInstance().getmTrustManager());
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier()); builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
} else { } else {
final SSLContext sslContext = SSLContext.getInstance("TLS"); final SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustAllCerts, new java.security.SecureRandom()); sslContext.init(null, trustAllCerts, Common.getSecureRandom());
final javax.net.ssl.SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory(); final SSLSocketFactory sslFactory = sslContext.getSocketFactory();
builder.sslSocketFactory(sslSocketFactory, trustManager); builder.sslSocketFactory(sslFactory, trustManager);
builder.hostnameVerifier(new HostnameVerifier() { builder.hostnameVerifier(new HostnameVerifier() {
@Override @Override
public boolean verify(String arg0, SSLSession arg1) { public boolean verify(String arg0, SSLSession arg1) {
@ -104,14 +107,18 @@ public class FLCommunication implements IFLCommunication {
} }
}); });
} }
return builder.build(); return builder.build();
} catch (NoSuchAlgorithmException | KeyManagementException e) { } catch (NoSuchAlgorithmException | KeyManagementException ex) {
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + e.getMessage())); LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + ex.getMessage()));
throw new RuntimeException(e); throw new IllegalArgumentException(ex);
} }
} }
/**
* Get the singleton object of the class FLCommunication.
*
* @return the singleton object of the class FLCommunication.
*/
public static FLCommunication getInstance() { public static FLCommunication getInstance() {
FLCommunication localRef = communication; FLCommunication localRef = communication;
if (localRef == null) { if (localRef == null) {
@ -138,6 +145,9 @@ public class FLCommunication implements IFLCommunication {
if (!response.isSuccessful()) { if (!response.isSuccessful()) {
throw new IOException("Unexpected code " + response); throw new IOException("Unexpected code " + response);
} }
if (response.body() == null) {
throw new IOException("the returned response is null");
}
return response.body().bytes(); return response.body().bytes();
} }
@ -159,11 +169,10 @@ public class FLCommunication implements IFLCommunication {
} }
@Override @Override
public void onFailure(Call call, IOException e) { public void onFailure(Call call, IOException ioException) {
asyncCallBack.onFailure(e); asyncCallBack.onFailure(ioException);
call.cancel(); call.cancel();
} }
}); });
} }
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,19 +13,40 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import java.util.logging.Logger; import java.util.logging.Logger;
/**
* Define job result callback function.
*
* @since 2021-06-30
*/
public class FLJobResultCallback implements IFLJobResultCallback { public class FLJobResultCallback implements IFLJobResultCallback {
private static final Logger LOGGER = Logger.getLogger(FLJobResultCallback.class.toString()); private static final Logger LOGGER = Logger.getLogger(FLJobResultCallback.class.toString());
/**
* Called at the end of an iteration for Fl job
*
* @param modelName the name of model
* @param iterationSeq Iteration number
* @param resultCode Status Code
*/
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) { public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) {
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode)); LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " +
iterationSeq + " resultCode: " + resultCode));
} }
/**
* Called on completion for Fl job
*
* @param modelName the name of model
* @param iterationCount total Iteration numbers
* @param resultCode Status Code
*/
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) { public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode)); LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " +
iterationCount + " resultCode: " + resultCode));
} }
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,11 +16,15 @@
package com.mindspore.flclient; package com.mindspore.flclient;
import com.mindspore.flclient.cipher.BaseUtil; import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.mindspore.flclient.model.AlInferBert; import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet; import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.CipherPublicParams; import mindspore.schema.CipherPublicParams;
import mindspore.schema.FLPlan; import mindspore.schema.FLPlan;
import mindspore.schema.ResponseCode; import mindspore.schema.ResponseCode;
@ -36,18 +40,21 @@ import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME; /**
import static com.mindspore.flclient.LocalFLParameter.ALBERT; * Defining the general process of federated learning tasks.
import static com.mindspore.flclient.LocalFLParameter.LENET; *
* @since 2021-06-30
*/
public class FLLiteClient { public class FLLiteClient {
private static final Logger LOGGER = Logger.getLogger(FLLiteClient.class.toString()); private static final Logger LOGGER = Logger.getLogger(FLLiteClient.class.toString());
private FLCommunication flCommunication; private static int iteration = 0;
private static Map<String, float[]> mapBeforeTrain;
private double dpNormClipFactor = 1.0d;
private double dpNormClipAdapt = 0.05d;
private FLCommunication flCommunication;
private FLClientStatus status; private FLClientStatus status;
private int retCode; private int retCode;
private static int iteration = 0;
private int iterations = 1; private int iterations = 1;
private int epochs = 1; private int epochs = 1;
private int batchSize = 16; private int batchSize = 16;
@ -55,22 +62,21 @@ public class FLLiteClient {
private byte[] prime; private byte[] prime;
private int featureSize; private int featureSize;
private int trainDataSize; private int trainDataSize;
private double dpEps = 100; private double dpEps = 100d;
private double dpDelta = 0.01; private double dpDelta = 0.01d;
public double dpNormClipFactor = 1.0;
public double dpNormClipAdapt = 0.05;
private FLParameter flParameter = FLParameter.getInstance(); private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private SecureProtocol secureProtocol = new SecureProtocol(); private SecureProtocol secureProtocol = new SecureProtocol();
private static Map<String, float[]> mapBeforeTrain;
private String nextRequestTime; private String nextRequestTime;
/**
* Defining a constructor of teh class FLLiteClient.
*/
public FLLiteClient() { public FLLiteClient() {
flCommunication = FLCommunication.getInstance(); flCommunication = FLCommunication.getInstance();
} }
public int setGlobalParameters(ResponseFLJob flJob) { private int setGlobalParameters(ResponseFLJob flJob) {
FLPlan flPlan = flJob.flPlanConfig(); FLPlan flPlan = flJob.flPlanConfig();
if (flPlan == null) { if (flPlan == null) {
LOGGER.severe(Common.addTag("[startFLJob] the FLPlan get from server is null")); LOGGER.severe(Common.addTag("[startFLJob] the FLPlan get from server is null"));
@ -90,14 +96,22 @@ public class FLLiteClient {
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize)); LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
trainLenet.setBatchSize(batchSize); trainLenet.setBatchSize(batchSize);
} else {
LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid"));
return -1;
} }
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize));
CipherPublicParams cipherPublicParams = flPlan.cipher(); CipherPublicParams cipherPublicParams = flPlan.cipher();
if (cipherPublicParams == null) {
LOGGER.severe(Common.addTag("[startFLJob] the cipherPublicParams returned from server is null"));
return -1;
}
String encryptLevel = cipherPublicParams.encryptType(); String encryptLevel = cipherPublicParams.encryptType();
if ("".equals(encryptLevel) || encryptLevel.isEmpty()) { if (encryptLevel == null || encryptLevel.isEmpty()) {
LOGGER.severe(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server is null, set the encryptLevel to NOT_ENCRYPT ")); LOGGER.severe(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server is null, set the " +
"encryptLevel to NOT_ENCRYPT "));
localFLParameter.setEncryptLevel(EncryptLevel.NOT_ENCRYPT.toString()); localFLParameter.setEncryptLevel(EncryptLevel.NOT_ENCRYPT.toString());
} else { } else {
localFLParameter.setEncryptLevel(encryptLevel); localFLParameter.setEncryptLevel(encryptLevel);
@ -113,10 +127,10 @@ public class FLLiteClient {
} }
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
if (minSecretNum <= 0) { if (minSecretNum <= 0) {
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid: <=0")); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid:" +
" <=0"));
return -1; return -1;
} }
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
break; break;
case DP_ENCRYPT: case DP_ENCRYPT:
dpEps = cipherPublicParams.dpEps(); dpEps = cipherPublicParams.dpEps();
@ -124,53 +138,97 @@ public class FLLiteClient {
dpNormClipFactor = cipherPublicParams.dpNormClip(); dpNormClipFactor = cipherPublicParams.dpNormClip();
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " + dpNormClipFactor)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " +
dpNormClipFactor));
break; break;
default: default:
LOGGER.info(Common.addTag("[startFLJob] NotEncrypt, do not set parameter for Encrypt")); LOGGER.info(Common.addTag("[startFLJob] NOT_ENCRYPT, do not set parameter for Encrypt"));
} }
return 0; return 0;
} }
/**
* Obtain retCode returned by server.
*
* @return the retCode returned by server.
*/
public int getRetCode() { public int getRetCode() {
return retCode; return retCode;
} }
/**
* Obtain current iteration returned by server.
*
* @return the current iteration returned by server.
*/
public int getIteration() { public int getIteration() {
return iteration; return iteration;
} }
/**
* Obtain total iterations for the task returned by server.
*
* @return the total iterations for the task returned by server.
*/
public int getIterations() { public int getIterations() {
return iterations; return iterations;
} }
public int getEpochs() { /**
return epochs; * Obtain the returned timestamp for next request from server.
} *
* @return the timestamp for next request.
public int getBatchSize() { */
return batchSize;
}
public String getNextRequestTime() { public String getNextRequestTime() {
return nextRequestTime; return nextRequestTime;
} }
public void setNextRequestTime(String nextRequestTime) { /**
this.nextRequestTime = nextRequestTime; * set the size of train date set.
} *
* @param trainDataSize the size of train date set.
*/
public void setTrainDataSize(int trainDataSize) { public void setTrainDataSize(int trainDataSize) {
this.trainDataSize = trainDataSize; this.trainDataSize = trainDataSize;
} }
public FLClientStatus checkStatus() { /**
return this.status; * Obtain the dpNormClipFactor.
*
* @return the dpNormClipFactor.
*/
public double getDpNormClipFactor() {
return dpNormClipFactor;
} }
/**
* Obtain the dpNormClipAdapt.
*
* @return the dpNormClipAdapt.
*/
public double getDpNormClipAdapt() {
return dpNormClipAdapt;
}
/**
* Set the dpNormClipAdapt.
*
* @param dpNormClipAdapt the dpNormClipAdapt.
*/
public void setDpNormClipAdapt(double dpNormClipAdapt) {
this.dpNormClipAdapt = dpNormClipAdapt;
}
/**
* Send serialized request message of startFLJob to server.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus startFLJob() { public FLClientStatus startFLJob() {
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify server====================================")); LOGGER.info(Common.addTag("[startFLJob] ====================================Verify " +
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); "server===================================="));
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
StartFLJob startFLJob = StartFLJob.getInstance(); StartFLJob startFLJob = StartFLJob.getInstance();
Date date = new Date(); Date date = new Date();
long time = date.getTime(); long time = date.getTime();
@ -179,8 +237,9 @@ public class FLLiteClient {
long start = Common.startTime("single startFLJob"); long start = Common.startTime("single startFLJob");
LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length)); LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length));
byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg); byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(message)) {
LOGGER.info(Common.addTag("[startFLJob] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[startFLJob] the server is not ready now, need wait some time and request " +
"again"));
status = FLClientStatus.RESTART; status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
nextRequestTime = ""; nextRequestTime = "";
@ -193,14 +252,15 @@ public class FLLiteClient {
status = judgeStartFLJob(startFLJob, responseDataBuf); status = judgeStartFLJob(startFLJob, responseDataBuf);
retCode = responseDataBuf.retcode(); retCode = responseDataBuf.retcode();
} catch (IOException e) { } catch (IOException e) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage())); LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " +
e.getMessage()));
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
} }
return status; return status;
} }
public FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) { private FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
iteration = responseDataBuf.iteration(); iteration = responseDataBuf.iteration();
FLClientStatus response = startFLJob.doResponse(responseDataBuf); FLClientStatus response = startFLJob.doResponse(responseDataBuf);
status = response; status = response;
@ -218,6 +278,10 @@ public class FLLiteClient {
break; break;
case RESTART: case RESTART:
FLPlan flPlan = responseDataBuf.flPlanConfig(); FLPlan flPlan = responseDataBuf.flPlanConfig();
if (flPlan == null) {
LOGGER.severe(Common.addTag("[startFLJob] the flPlan returned from server is null"));
return FLClientStatus.FAILED;
}
iterations = flPlan.iterations(); iterations = flPlan.iterations();
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations)); LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
nextRequestTime = responseDataBuf.nextReqTime(); nextRequestTime = responseDataBuf.nextReqTime();
@ -226,14 +290,21 @@ public class FLLiteClient {
LOGGER.severe(Common.addTag("[startFLJob] startFLJob failed")); LOGGER.severe(Common.addTag("[startFLJob] startFLJob failed"));
break; break;
default: default:
LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range <SUCCESS, WAIT, FAILED, Restart>")); LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range " +
"<SUCCESS, WAIT, FAILED, Restart>"));
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
} }
return status; return status;
} }
/**
* Define the training process.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus localTrain() { public FLClientStatus localTrain() {
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "====================================")); LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration +
"===================================="));
status = FLClientStatus.SUCCESS; status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED; retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
@ -254,12 +325,22 @@ public class FLLiteClient {
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
} }
} else {
LOGGER.severe(Common.addTag("[train] the flName is not valid"));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
} }
return status; return status;
} }
/**
* Send serialized request message of updateModel to server.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus updateModel() { public FLClientStatus updateModel() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
UpdateModel updateModelBuf = UpdateModel.getInstance(); UpdateModel updateModelBuf = UpdateModel.getInstance();
byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize); byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize);
if (updateModelBuf.getStatus() == FLClientStatus.FAILED) { if (updateModelBuf.getStatus() == FLClientStatus.FAILED) {
@ -270,8 +351,9 @@ public class FLLiteClient {
long start = Common.startTime("single updateModel"); long start = Common.startTime("single updateModel");
LOGGER.info(Common.addTag("[updateModel] the request message length: " + updateModelBuffer.length)); LOGGER.info(Common.addTag("[updateModel] the request message length: " + updateModelBuffer.length));
byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer); byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(message)) {
LOGGER.info(Common.addTag("[updateModel] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[updateModel] the server is not ready now, need wait some time and request" +
" again"));
status = FLClientStatus.RESTART; status = FLClientStatus.RESTART;
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
nextRequestTime = ""; nextRequestTime = "";
@ -288,23 +370,31 @@ public class FLLiteClient {
} }
LOGGER.info(Common.addTag("[updateModel] get response from server ok!")); LOGGER.info(Common.addTag("[updateModel] get response from server ok!"));
} catch (IOException e) { } catch (IOException e) {
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage())); LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " +
e.getMessage()));
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
} }
return status; return status;
} }
/**
* Send serialized request message of getModel to server.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus getModel() { public FLClientStatus getModel() {
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
GetModel getModelBuf = GetModel.getInstance(); GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration); byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration);
try { try {
long start = Common.startTime("single getModel"); long start = Common.startTime("single getModel");
LOGGER.info(Common.addTag("[getModel] the request message length: " + buffer.length)); LOGGER.info(Common.addTag("[getModel] the request message length: " + buffer.length));
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer); byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(message)) {
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
"again"));
status = FLClientStatus.WAIT; status = FLClientStatus.WAIT;
return status; return status;
} }
@ -327,6 +417,12 @@ public class FLLiteClient {
return status; return status;
} }
/**
* Obtain the weight of the model before training.
*
* @param map a map to store the weight of the model.
* @return the weight.
*/
public static synchronized Map<String, float[]> getOldMapCopy(Map<String, float[]> map) { public static synchronized Map<String, float[]> getOldMapCopy(Map<String, float[]> map) {
if (mapBeforeTrain == null) { if (mapBeforeTrain == null) {
Map<String, float[]> copyMap = new TreeMap<>(); Map<String, float[]> copyMap = new TreeMap<>();
@ -334,7 +430,8 @@ public class FLLiteClient {
float[] data = map.get(key); float[] data = map.get(key);
int dataLen = data.length; int dataLen = data.length;
float[] weights = new float[dataLen]; float[] weights = new float[dataLen];
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) { if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) &&
(key.indexOf("learning") < 0)) {
for (int j = 0; j < dataLen; j++) { for (int j = 0; j < dataLen; j++) {
float weight = data[j]; float weight = data[j];
weights[j] = weight; weights[j] = weight;
@ -348,7 +445,8 @@ public class FLLiteClient {
float[] data = map.get(key); float[] data = map.get(key);
float[] copyData = mapBeforeTrain.get(key); float[] copyData = mapBeforeTrain.get(key);
int dataLen = data.length; int dataLen = data.length;
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) { if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) &&
(key.indexOf("learning") < 0)) {
for (int j = 0; j < dataLen; j++) { for (int j = 0; j < dataLen; j++) {
copyData[j] = data[j]; copyData[j] = data[j];
} }
@ -358,18 +456,25 @@ public class FLLiteClient {
return mapBeforeTrain; return mapBeforeTrain;
} }
/**
* Obtain pairwise mask and individual mask.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus getFeatureMask() { public FLClientStatus getFeatureMask() {
FLClientStatus curStatus; FLClientStatus curStatus;
switch (localFLParameter.getEncryptLevel()) { switch (localFLParameter.getEncryptLevel()) {
case PW_ENCRYPT: case PW_ENCRYPT:
LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" + localFLParameter.getEncryptLevel().toString() + ">")); LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" +
localFLParameter.getEncryptLevel().toString() + ">"));
secureProtocol.setPWParameter(iteration, minSecretNum, prime, featureSize); secureProtocol.setPWParameter(iteration, minSecretNum, prime, featureSize);
curStatus = secureProtocol.pwCreateMask(); curStatus = secureProtocol.pwCreateMask();
if (curStatus == FLClientStatus.RESTART) { if (curStatus == FLClientStatus.RESTART) {
nextRequestTime = secureProtocol.getNextRequestTime(); nextRequestTime = secureProtocol.getNextRequestTime();
} }
retCode = secureProtocol.getRetCode(); retCode = secureProtocol.getRetCode();
LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" + localFLParameter.getEncryptLevel().toString() + "> : " + curStatus)); LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" +
localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
return curStatus; return curStatus;
case DP_ENCRYPT: case DP_ENCRYPT:
Map<String, float[]> map = new HashMap<String, float[]>(); Map<String, float[]> map = new HashMap<String, float[]>();
@ -388,7 +493,7 @@ public class FLLiteClient {
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
LOGGER.info(Common.addTag("[Encrypt] set parameters for DPEncrypt!")); LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!"));
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
case NOT_ENCRYPT: case NOT_ENCRYPT:
retCode = ResponseCode.SUCCEED; retCode = ResponseCode.SUCCEED;
@ -401,6 +506,11 @@ public class FLLiteClient {
} }
} }
/**
* Reconstruct the secrets used for unmasking model weights.
*
* @return current status code in client.
*/
public FLClientStatus unMasking() { public FLClientStatus unMasking() {
FLClientStatus curStatus; FLClientStatus curStatus;
switch (localFLParameter.getEncryptLevel()) { switch (localFLParameter.getEncryptLevel()) {
@ -413,7 +523,7 @@ public class FLLiteClient {
} }
return curStatus; return curStatus;
case DP_ENCRYPT: case DP_ENCRYPT:
LOGGER.info(Common.addTag("[Encrypt] DPEncrypt do not need unmasking")); LOGGER.info(Common.addTag("[Encrypt] DP_ENCRYPT do not need unmasking"));
retCode = ResponseCode.SUCCEED; retCode = ResponseCode.SUCCEED;
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
case NOT_ENCRYPT: case NOT_ENCRYPT:
@ -427,18 +537,26 @@ public class FLLiteClient {
} }
} }
/**
* Evaluate model after getting model from server.
*
* @return the status code in client.
*/
public FLClientStatus evaluateModel() { public FLClientStatus evaluateModel() {
status = FLClientStatus.SUCCESS; status = FLClientStatus.SUCCESS;
retCode = ResponseCode.SUCCEED; retCode = ResponseCode.SUCCEED;
LOGGER.info(Common.addTag("===================================evaluate model after getting model from server===================================")); LOGGER.info(Common.addTag("===================================evaluate model after getting model from " +
"server==================================="));
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
float acc = 0; float acc = 0;
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) { if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlInferBert alInferBert = AlInferBert.getInstance(); AlInferBert alInferBert = AlInferBert.getInstance();
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true); int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
flParameter.getIdsFile(), true);
if (dataSize <= 0) { if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the return dataSize<=0")); LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the " +
"return dataSize<=0"));
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
return status; return status;
@ -447,47 +565,66 @@ public class FLLiteClient {
} else { } else {
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod())); LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile()); int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
flParameter.getIdsFile());
if (dataSize <= 0) { if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0")); LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the " +
"return dataSize<=0"));
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
return status; return status;
} }
acc = alTrainBert.evalModel(); acc = alTrainBert.evalModel();
} }
if (acc == Float.NaN) { if (Float.isNaN(acc)) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <evalModel>: the return acc is NAN")); LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <evalModel>: the return acc is NAN"));
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
return status; return status;
} }
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile())); LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() +
" idsFile: " + flParameter.getIdsFile()));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc)); LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
} else if (flParameter.getFlName().equals(LENET)) { } else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0], flParameter.getTestDataset().split(",")[1]); if (flParameter.getTestDataset().split(",").length < 2) {
LOGGER.severe(Common.addTag("[evaluate] the set testDataPath for lenet is not valid, should be the " +
"format of <data.bin,label.bin> "));
status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError;
return status;
}
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0],
flParameter.getTestDataset().split(",")[1]);
if (dataSize <= 0) { if (dataSize <= 0) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return dataSize<=0")); LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return " +
"dataSize<=0"));
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
return status; return status;
} }
float acc = trainLenet.evalModel(); float acc = trainLenet.evalModel();
if (acc == Float.NaN) { if (Float.isNaN(acc)) {
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc is NAN")); LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc" +
" is NAN"));
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
return status; return status;
} }
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset().split(",")[0] + " labelPath: " + flParameter.getTestDataset().split(",")[1])); LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
flParameter.getTestDataset().split(",")[0] + " labelPath: " +
flParameter.getTestDataset().split(",")[1]));
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc)); LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
} }
return status; return status;
} }
/** /**
* @param dataPath, train or test dataset and label set * Set date path.
*
* @param dataPath, train or test dataset and label set.
* @return date size.
*/ */
public int setInput(String dataPath) { public int setInput(String dataPath) {
retCode = ResponseCode.SUCCEED; retCode = ResponseCode.SUCCEED;
@ -496,15 +633,18 @@ public class FLLiteClient {
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
dataSize = alTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile()); dataSize = alTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile())); LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " " +
"vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
} else if (flParameter.getFlName().equals(LENET)) { } else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
if (dataPath.split(",").length < 2) { if (dataPath.split(",").length < 2) {
LOGGER.info(Common.addTag("[set input] the set dataPath for lenet is not valid, should be the format of <data.bin,label.bin>")); LOGGER.severe(Common.addTag("[set input] the set dataPath for lenet is not valid, should be the " +
"format of <data.bin,label.bin> "));
return -1; return -1;
} }
dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]); dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]);
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " + +dataSize + " labelPath: " + dataPath.split(",")[1])); LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " +
dataSize + " labelPath: " + dataPath.split(",")[1]));
} }
if (dataSize <= 0) { if (dataSize <= 0) {
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
@ -513,36 +653,48 @@ public class FLLiteClient {
return dataSize; return dataSize;
} }
/**
* Initialization session.
*
* @return the status code in client.
*/
public FLClientStatus initSession() { public FLClientStatus initSession() {
int tag = 0; int tag = 0;
retCode = ResponseCode.SUCCEED; retCode = ResponseCode.SUCCEED;
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
"Train Session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1")); LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
"is -1"));
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session=============")); LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " +
"Create inference Session============="));
AlInferBert alInferBert = AlInferBert.getInstance(); AlInferBert alInferBert = AlInferBert.getInstance();
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) { } else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
"Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true); tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
} }
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1")); LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is " +
"-1"));
retCode = ResponseCode.RequestError; retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
} }
@Override /**
protected void finalize() { * Free session.
*/
protected void freeSession() {
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("===========free train session=============")); LOGGER.info(Common.addTag("===========free train session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
@ -558,5 +710,4 @@ public class FLLiteClient {
SessionUtil.free(trainLenet.getTrainSession()); SessionUtil.free(trainLenet.getTrainSession());
} }
} }
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,22 +13,36 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient;
import java.util.logging.Logger; package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT; import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import java.util.Arrays;
import java.util.UUID;
import java.util.logging.Logger;
/**
* Defines global parameters used during federated learning and these parameters are provided for users to set.
*
* @since 2021-06-30
*/
public class FLParameter { public class FLParameter {
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString()); private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
/**
* The timeout interval for communication on the device.
*/
public static final int TIME_OUT = 100; public static final int TIME_OUT = 100;
/**
* The waiting time of repeated requests.
*/
public static final int SLEEP_TIME = 1000; public static final int SLEEP_TIME = 1000;
private static volatile FLParameter flParameter;
private String hostName; private String domainName;
private String certPath; private String certPath;
private boolean useHttps = false;
private String trainDataset; private String trainDataset;
private String vocabFile = "null"; private String vocabFile = "null";
private String idsFile = "null"; private String idsFile = "null";
@ -37,22 +51,21 @@ public class FLParameter {
private String trainModelPath; private String trainModelPath;
private String inferModelPath; private String inferModelPath;
private String clientID; private String clientID;
private String ip;
private int port;
private boolean useSSL = false; private boolean useSSL = false;
private int timeOut; private int timeOut;
private int sleepTime; private int sleepTime;
private boolean useElb = false; private boolean ifUseElb = false;
private int serverNum = 1; private int serverNum = 1;
private boolean timer = true; private FLParameter() {
private int timeWindow = 6000; clientID = UUID.randomUUID().toString();
private int reRequestNum = timeWindow / SLEEP_TIME + 1; }
private static volatile FLParameter flParameter;
private FLParameter() {}
/**
* Get the singleton object of the class FLParameter.
*
* @return the singleton object of the class FLParameter.
*/
public static FLParameter getInstance() { public static FLParameter getInstance() {
FLParameter localRef = flParameter; FLParameter localRef = flParameter;
if (localRef == null) { if (localRef == null) {
@ -66,95 +79,100 @@ public class FLParameter {
return localRef; return localRef;
} }
public String getHostName() { public String getDomainName() {
if ("".equals(hostName) || hostName.isEmpty()) { if (domainName == null || domainName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <hostName> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is null or empty, please set it " +
throw new RuntimeException(); "before use"));
throw new IllegalArgumentException();
} }
return hostName; return domainName;
} }
public void setHostName(String hostName) { public void setDomainName(String domainName) {
this.hostName = hostName; if (domainName == null || domainName.isEmpty() || (!("https".equals(domainName.split(":")[0]) || "http".equals(domainName.split(":")[0])))) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is not valid, it should be like " +
"as https://...... or http://......, please check it before set"));
throw new IllegalArgumentException();
}
this.domainName = domainName;
} }
public String getCertPath() { public String getCertPath() {
if ("".equals(certPath) || certPath.isEmpty()) { if (certPath == null || certPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null or empty, please set it " +
throw new RuntimeException(); "before use"));
throw new IllegalArgumentException();
} }
return certPath; return certPath;
} }
public void setCertPath(String certPath) { public void setCertPath(String certPath) {
certPath = Common.getRealPath(certPath); String realCertPath = Common.getRealPath(certPath);
if (Common.checkPath(certPath)) { if (Common.checkPath(realCertPath)) {
this.certPath = certPath; this.certPath = realCertPath;
} else { } else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it before set")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it " +
throw new RuntimeException(); "before set"));
throw new IllegalArgumentException();
} }
} }
public boolean isUseHttps() {
return useHttps;
}
public void setUseHttps(boolean useHttps) {
this.useHttps = useHttps;
}
public String getTrainDataset() { public String getTrainDataset() {
if ("".equals(trainDataset) || trainDataset.isEmpty()) { if (trainDataset == null || trainDataset.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null or empty, please set " +
throw new RuntimeException(); "it before use"));
throw new IllegalArgumentException();
} }
return trainDataset; return trainDataset;
} }
public void setTrainDataset(String trainDataset) { public void setTrainDataset(String trainDataset) {
trainDataset = Common.getRealPath(trainDataset); String realTrainDataset = Common.getRealPath(trainDataset);
if (Common.checkPath(trainDataset)) { if (Common.checkPath(realTrainDataset)) {
this.trainDataset = trainDataset; this.trainDataset = realTrainDataset;
} else { } else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it before set")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it " +
throw new RuntimeException(); "before set"));
throw new IllegalArgumentException();
} }
} }
public String getVocabFile() { public String getVocabFile() {
if ("null".equals(vocabFile) && ALBERT.equals(flName)) { if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before " +
throw new RuntimeException(); "use"));
throw new IllegalArgumentException();
} }
return vocabFile; return vocabFile;
} }
public void setVocabFile(String vocabFile) { public void setVocabFile(String vocabFile) {
vocabFile = Common.getRealPath(vocabFile); String realVocabFile = Common.getRealPath(vocabFile);
if (Common.checkPath(vocabFile)) { if (Common.checkPath(realVocabFile)) {
this.vocabFile = vocabFile; this.vocabFile = realVocabFile;
} else { } else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it before set")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it " +
throw new RuntimeException(); "before set"));
throw new IllegalArgumentException();
} }
} }
public String getIdsFile() { public String getIdsFile() {
if ("null".equals(idsFile) && ALBERT.equals(flName)) { if ("null".equals(idsFile) && ALBERT.equals(flName)) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
throw new RuntimeException(); throw new IllegalArgumentException();
} }
return idsFile; return idsFile;
} }
public void setIdsFile(String idsFile) { public void setIdsFile(String idsFile) {
idsFile = Common.getRealPath(idsFile); String realIdsFile = Common.getRealPath(idsFile);
if (Common.checkPath(idsFile)) { if (Common.checkPath(realIdsFile)) {
this.idsFile = idsFile; this.idsFile = realIdsFile;
} else { } else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it before set")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it " +
throw new RuntimeException(); "before set"));
throw new IllegalArgumentException();
} }
} }
@ -163,19 +181,21 @@ public class FLParameter {
} }
public void setTestDataset(String testDataset) { public void setTestDataset(String testDataset) {
testDataset = Common.getRealPath(testDataset); String realTestDataset = Common.getRealPath(testDataset);
if (Common.checkPath(testDataset)) { if (Common.checkPath(realTestDataset)) {
this.testDataset = testDataset; this.testDataset = realTestDataset;
} else { } else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it before set")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it " +
throw new RuntimeException(); "before set"));
throw new IllegalArgumentException();
} }
} }
public String getFlName() { public String getFlName() {
if ("".equals(flName) || flName.isEmpty()) { if (flName == null || flName.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null or empty, please set it " +
throw new RuntimeException(); "before use"));
throw new IllegalArgumentException();
} }
return flName; return flName;
} }
@ -184,61 +204,50 @@ public class FLParameter {
if (Common.checkFLName(flName)) { if (Common.checkFLName(flName)) {
this.flName = flName; this.flName = flName;
} else { } else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in flNameTrustList, please check it before set")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in FL_NAME_TRUST_LIST: " +
throw new RuntimeException(); Arrays.toString(Common.FL_NAME_TRUST_LIST.toArray(new String[0])) + ", please check it before " +
"set"));
throw new IllegalArgumentException();
} }
} }
public String getTrainModelPath() { public String getTrainModelPath() {
if ("".equals(trainModelPath) || trainModelPath.isEmpty()) { if (trainModelPath == null || trainModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null or empty, please set" +
throw new RuntimeException(); " it before use"));
throw new IllegalArgumentException();
} }
return trainModelPath; return trainModelPath;
} }
public void setTrainModelPath(String trainModelPath) { public void setTrainModelPath(String trainModelPath) {
trainModelPath = Common.getRealPath(trainModelPath); String realTrainModelPath = Common.getRealPath(trainModelPath);
if (Common.checkPath(trainModelPath)) { if (Common.checkPath(realTrainModelPath)) {
this.trainModelPath = trainModelPath; this.trainModelPath = realTrainModelPath;
} else { } else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check it before set")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check " +
throw new RuntimeException(); "it before set"));
throw new IllegalArgumentException();
} }
} }
public String getInferModelPath() { public String getInferModelPath() {
if ("".equals(inferModelPath) || inferModelPath.isEmpty()) { if (inferModelPath == null || inferModelPath.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null or empty, please set" +
throw new RuntimeException(); " it before use"));
throw new IllegalArgumentException();
} }
return inferModelPath; return inferModelPath;
} }
public void setInferModelPath(String inferModelPath) { public void setInferModelPath(String inferModelPath) {
inferModelPath = Common.getRealPath(inferModelPath); String realInferModelPath = Common.getRealPath(inferModelPath);
if (Common.checkPath(inferModelPath)) { if (Common.checkPath(realInferModelPath)) {
this.inferModelPath = inferModelPath; this.inferModelPath = realInferModelPath;
} else { } else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check it before set")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check " +
throw new RuntimeException(); "it before set"));
} throw new IllegalArgumentException();
}
public String getIp() {
if ("".equals(ip) || ip.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is null, please set it before use"));
throw new RuntimeException();
}
return ip;
}
public void setIp(String ip) {
if (Common.checkIP(ip)) {
this.ip = ip;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is not valid, please check it before set"));
throw new RuntimeException();
} }
} }
@ -250,23 +259,6 @@ public class FLParameter {
this.useSSL = useSSL; this.useSSL = useSSL;
} }
public int getPort() {
if (port == 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is null, please set it before use"));
throw new RuntimeException();
}
return port;
}
public void setPort(int port) {
if (Common.checkPort(port)) {
this.port = port;
} else {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is not valid, please check it before set"));
throw new RuntimeException();
}
}
public int getTimeOut() { public int getTimeOut() {
return timeOut; return timeOut;
} }
@ -284,17 +276,18 @@ public class FLParameter {
} }
public boolean isUseElb() { public boolean isUseElb() {
return useElb; return ifUseElb;
} }
public void setUseElb(boolean useElb) { public void setUseElb(boolean ifUseElb) {
this.useElb = useElb; this.ifUseElb = ifUseElb;
} }
public int getServerNum() { public int getServerNum() {
if (serverNum <= 0) { if (serverNum <= 0) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> is <= 0, it should be > 0, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> <= 0, it should be > 0, please " +
throw new RuntimeException(); "set it before use"));
throw new IllegalArgumentException();
} }
return serverNum; return serverNum;
} }
@ -303,40 +296,11 @@ public class FLParameter {
this.serverNum = serverNum; this.serverNum = serverNum;
} }
public boolean isTimer() {
return timer;
}
public void setTimer(boolean timer) {
this.timer = timer;
}
public int getTimeWindow() {
return timeWindow;
}
public void setTimeWindow(int timeWindow) {
this.timeWindow = timeWindow;
}
public int getReRequestNum() {
return reRequestNum;
}
public void setReRequestNum(int reRequestNum) {
this.reRequestNum = reRequestNum;
}
public String getClientID() { public String getClientID() {
if ("".equals(clientID) || clientID.isEmpty()) { if (clientID == null || clientID.isEmpty()) {
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null, please set it before use")); LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check"));
throw new RuntimeException(); throw new IllegalArgumentException();
} }
return clientID; return clientID;
} }
public void setClientID(String clientID) {
this.clientID = clientID;
}
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,13 +13,19 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlInferBert; import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet; import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap; import mindspore.schema.FeatureMap;
import mindspore.schema.RequestGetModel; import mindspore.schema.RequestGetModel;
import mindspore.schema.ResponseCode; import mindspore.schema.ResponseCode;
@ -29,57 +35,30 @@ import java.util.ArrayList;
import java.util.Date; import java.util.Date;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT; /**
import static com.mindspore.flclient.LocalFLParameter.LENET; * Define the serialization method, handle the response message returned from server for getModel request.
*
* @since 2021-06-30
*/
public class GetModel { public class GetModel {
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
private static volatile GetModel getModel;
static { static {
System.loadLibrary("mindspore-lite-jni"); System.loadLibrary("mindspore-lite-jni");
} }
class RequestGetModelBuilder {
private FlatBufferBuilder builder;
private int nameOffset = 0;
private int iteration = 0;
private int timeStampOffset = 0;
public RequestGetModelBuilder() {
builder = new FlatBufferBuilder();
}
public RequestGetModelBuilder flName(String name) {
this.nameOffset = this.builder.createString(name);
return this;
}
public RequestGetModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timeStampOffset = builder.createString(String.valueOf(time));
return this;
}
public RequestGetModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
public byte[] build() {
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
builder.finish(root);
return builder.sizedByteArray();
}
}
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
private static volatile GetModel getModel;
private GetModel() {
}
private FLParameter flParameter = FLParameter.getInstance(); private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private GetModel() {
}
/**
* Get the singleton object of the class GetModel.
*
* @return the singleton object of the class GetModel.
*/
public static GetModel getInstance() { public static GetModel getInstance() {
GetModel localRef = getModel; GetModel localRef = getModel;
if (localRef == null) { if (localRef == null) {
@ -93,7 +72,18 @@ public class GetModel {
return localRef; return localRef;
} }
/**
* Get a flatBuffer builder of RequestGetModel.
*
* @param name the model name.
* @param iteration current iteration of federated learning task.
* @return the flatBuffer builder of RequestGetModel in byte[] format.
*/
public byte[] getRequestGetModel(String name, int iteration) { public byte[] getRequestGetModel(String name, int iteration) {
if (name == null || name.isEmpty()) {
LOGGER.severe(Common.addTag("[GetModel] the input parameter of <name> is null or empty, please check!"));
throw new IllegalArgumentException();
}
RequestGetModelBuilder builder = new RequestGetModelBuilder(); RequestGetModelBuilder builder = new RequestGetModelBuilder();
return builder.iteration(iteration).flName(name).time().build(); return builder.iteration(iteration).flName(name).time().build();
} }
@ -107,6 +97,10 @@ public class GetModel {
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>(); ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) { for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i); FeatureMap feature = responseDataBuf.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname(); String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) { if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature); albertFeatureMaps.add(feature);
@ -116,36 +110,46 @@ public class GetModel {
} else { } else {
continue; continue;
} }
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength:" +
" " + feature.dataLength()));
} }
int tag = 0; int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------")); LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " +
"model-----------------"));
AlInferBert alInferBert = AlInferBert.getInstance(); AlInferBert alInferBert = AlInferBert.getInstance();
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps); tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
inferFeatureMaps);
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>")); LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------")); LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps); tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
albertFeatureMaps);
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>")); LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) { } else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>")); LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>(); ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) { for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i); FeatureMap feature = responseDataBuf.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname(); String featureName = feature.weightFullname();
featureMaps.add(feature); featureMaps.add(feature);
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength())); LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
feature.dataLength()));
} }
int tag = 0; int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------")); LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps); tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
featureMaps);
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>")); LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
@ -159,9 +163,14 @@ public class GetModel {
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>(); ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) { for (int i = 0; i < fmCount; i++) {
FeatureMap feature = responseDataBuf.featureMap(i); FeatureMap feature = responseDataBuf.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname(); String featureName = feature.weightFullname();
featureMaps.add(feature); featureMaps.add(feature);
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength())); LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
feature.dataLength()));
} }
int tag = 0; int tag = 0;
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------")); LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
@ -174,7 +183,12 @@ public class GetModel {
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
} }
/**
* Handle the response message returned from server.
*
* @param responseDataBuf the response message returned from server.
* @return the status code corresponding to the response message.
*/
public FLClientStatus doResponse(ResponseGetModel responseDataBuf) { public FLClientStatus doResponse(ResponseGetModel responseDataBuf) {
LOGGER.info(Common.addTag("[getModel] ==========get model content is:================")); LOGGER.info(Common.addTag("[getModel] ==========get model content is:================"));
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode())); LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode()));
@ -186,13 +200,15 @@ public class GetModel {
switch (retCode) { switch (retCode) {
case (ResponseCode.SUCCEED): case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[getModel] getModel response success")); LOGGER.info(Common.addTag("[getModel] getModel response success"));
if (ALBERT.equals(flParameter.getFlName())) { if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>")); LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
status = parseResponseAlbert(responseDataBuf); status = parseResponseAlbert(responseDataBuf);
} else if (LENET.equals(flParameter.getFlName())) { } else if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>")); LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
status = parseResponseLenet(responseDataBuf); status = parseResponseLenet(responseDataBuf);
} else {
LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert"));
throw new IllegalArgumentException();
} }
return status; return status;
case (ResponseCode.SucNotReady): case (ResponseCode.SucNotReady):
@ -211,4 +227,42 @@ public class GetModel {
} }
} }
class RequestGetModelBuilder {
private FlatBufferBuilder builder;
private int nameOffset = 0;
private int iteration = 0;
private int timeStampOffset = 0;
public RequestGetModelBuilder() {
builder = new FlatBufferBuilder();
}
private RequestGetModelBuilder flName(String name) {
if (name == null || name.isEmpty()) {
LOGGER.severe(Common.addTag("[GetModel] the input parameter of <name> is null or empty, please " +
"check!"));
throw new IllegalArgumentException();
}
this.nameOffset = this.builder.createString(name);
return this;
}
private RequestGetModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timeStampOffset = builder.createString(String.valueOf(time));
return this;
}
private RequestGetModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
private byte[] build() {
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
builder.finish(root);
return builder.sizedByteArray();
}
}
} }

View File

@ -1,25 +1,42 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import java.io.IOException; import java.io.IOException;
/**
* Define asynchronous communication call back interface.
*
* @since 2021-06-30
*/
public interface IAsyncCallBack { public interface IAsyncCallBack {
public FLClientStatus onFailure(IOException exception); /**
* Automatically invoked when the request fails.
*
* @param exception the catch exception.
* @return the status code in client.
*/
FLClientStatus onFailure(IOException exception);
public FLClientStatus onResponse(byte[] msg); /**
* Automatically invoked when a response message is processed.
*
* @param msg the response message.
* @return the status code in client.
*/
FLClientStatus onResponse(byte[] msg);
} }

View File

@ -1,33 +1,54 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
/** /**
* @author smurf * Define basic communication interface.
* *
* @since 2021-06-30
*/ */
public interface IFLCommunication { public interface IFLCommunication {
/**
* Sets the timeout interval for communication on the device.
*
* @param timeout the timeout interval for communication on the device.
* @throws TimeoutException catch TimeoutException.
*/
void setTimeOut(int timeout) throws TimeoutException;
public void setTimeOut(int timeout) throws TimeoutException; /**
* Synchronization request function.
public byte[] syncRequest(String url, byte[] msg) throws Exception; *
* @param url the URL for device-sever interaction set by user.
public void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception; * @param msg the message need to be sent to server.
} * @return the response message.
* @throws Exception catch Exception.
*/
byte[] syncRequest(String url, byte[] msg) throws Exception;
/**
* Asynchronous request function.
*
* @param url the URL for device-sever interaction set by user.
* @param msg the message need to be sent to server.
* @param callBack the call back object.
* @throws Exception catch Exception.
*/
void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception;
}

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,22 +13,30 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
/**
* Define job result callback function interface.
*
* @since 2021-06-30
*/
public interface IFLJobResultCallback { public interface IFLJobResultCallback {
/** /**
* Called at the end of an iteration for Fl job * Called at the end of an iteration for Fl job
* @param modelName the name of model *
* @param modelName the name of model
* @param iterationSeq Iteration number * @param iterationSeq Iteration number
* @param resultCode Status Code * @param resultCode Status Code
*/ */
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode); void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode);
/** /**
* Called on completion for Fl job * Called on completion for Fl job
* @param modelName the name of model *
* @param modelName the name of model
* @param iterationCount total Iteration numbers * @param iterationCount total Iteration numbers
* @param resultCode Status Code * @param resultCode Status Code
*/ */
public void onFlJobFinished(String modelName, int iterationCount, int resultCode); void onFlJobFinished(String modelName, int iterationCount, int resultCode);
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,28 +13,60 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import org.bouncycastle.math.ec.rfc7748.X25519;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.logging.Logger; import java.util.logging.Logger;
/**
* Defines global parameters used internally during federated learning.
*
* @since 2021-06-30
*/
public class LocalFLParameter { public class LocalFLParameter {
private static final Logger LOGGER = Logger.getLogger(LocalFLParameter.class.toString()); private static final Logger LOGGER = Logger.getLogger(LocalFLParameter.class.toString());
/**
* Seed length used to generate random perturbations
*/
public static final int SEED_SIZE = 32; public static final int SEED_SIZE = 32;
public static final int IVEC_LEN = 16;
/**
* The length of IV value
*/
public static final int I_VEC_LEN = 16;
/**
* The length of salt value
*/
public static final int SALT_SIZE = 32;
/**
* the key length
*/
public static final int KEY_LEN = X25519.SCALAR_SIZE;
/**
* The model name supported by federated learning tasks: "lenet".
*/
public static final String LENET = "lenet"; public static final String LENET = "lenet";
/**
* The model name supported by federated learning tasks: "albert".
*/
public static final String ALBERT = "albert"; public static final String ALBERT = "albert";
private static volatile LocalFLParameter localFLParameter;
private List<String> classifierWeightName = new ArrayList<>(); private List<String> classifierWeightName = new ArrayList<>();
private List<String> albertWeightName = new ArrayList<>(); private List<String> albertWeightName = new ArrayList<>();
private String flID; private String flID;
private String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString(); private String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString();
private String earlyStopMod = EarlyStopMod.NOT_EARLY_STOP.toString(); private String earlyStopMod = EarlyStopMod.NOT_EARLY_STOP.toString();
private String serverMod = ServerMod.HYBRID_TRAINING.toString(); private String serverMod = ServerMod.HYBRID_TRAINING.toString();
private String safeMod = "The cluster is in safemode.";
private static volatile LocalFLParameter localFLParameter;
private LocalFLParameter() { private LocalFLParameter() {
// set classifierWeightName albertWeightName // set classifierWeightName albertWeightName
@ -42,6 +74,11 @@ public class LocalFLParameter {
Common.setAlbertWeightName(albertWeightName); Common.setAlbertWeightName(albertWeightName);
} }
/**
* Get the singleton object of the class LocalFLParameter.
*
* @return the singleton object of the class LocalFLParameter.
*/
public static LocalFLParameter getInstance() { public static LocalFLParameter getInstance() {
LocalFLParameter localRef = localFLParameter; LocalFLParameter localRef = localFLParameter;
if (localRef == null) { if (localRef == null) {
@ -57,8 +94,9 @@ public class LocalFLParameter {
public List<String> getClassifierWeightName() { public List<String> getClassifierWeightName() {
if (classifierWeightName.isEmpty()) { if (classifierWeightName.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use")); LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please " +
throw new RuntimeException(); "set it before use"));
throw new IllegalArgumentException();
} }
return classifierWeightName; return classifierWeightName;
} }
@ -69,8 +107,9 @@ public class LocalFLParameter {
public List<String> getAlbertWeightName() { public List<String> getAlbertWeightName() {
if (albertWeightName.isEmpty()) { if (albertWeightName.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use")); LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please " +
throw new RuntimeException(); "set it before use"));
throw new IllegalArgumentException();
} }
return albertWeightName; return albertWeightName;
} }
@ -80,14 +119,20 @@ public class LocalFLParameter {
} }
public String getFlID() { public String getFlID() {
if ("".equals(flID) || flID == null) { if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before use")); LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before " +
throw new RuntimeException(); "use"));
throw new IllegalArgumentException();
} }
return flID; return flID;
} }
public void setFlID(String flID) { public void setFlID(String flID) {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please check it before " +
"set"));
throw new IllegalArgumentException();
}
this.flID = flID; this.flID = flID;
} }
@ -96,6 +141,18 @@ public class LocalFLParameter {
} }
public void setEncryptLevel(String encryptLevel) { public void setEncryptLevel(String encryptLevel) {
if (encryptLevel == null || encryptLevel.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is null, please check it " +
"before set"));
throw new IllegalArgumentException();
}
if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) &&
(!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) &&
(!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is " + encryptLevel + " ," +
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before set"));
throw new IllegalArgumentException();
}
this.encryptLevel = encryptLevel; this.encryptLevel = encryptLevel;
} }
@ -104,6 +161,19 @@ public class LocalFLParameter {
} }
public void setEarlyStopMod(String earlyStopMod) { public void setEarlyStopMod(String earlyStopMod) {
if (earlyStopMod == null || earlyStopMod.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is null, please check it " +
"before set"));
throw new IllegalArgumentException();
}
if ((!EarlyStopMod.NOT_EARLY_STOP.toString().equals(earlyStopMod)) &&
(!EarlyStopMod.LOSS_ABS.toString().equals(earlyStopMod)) &&
(!EarlyStopMod.LOSS_DIFF.toString().equals(earlyStopMod)) &&
(!EarlyStopMod.WEIGHT_DIFF.toString().equals(earlyStopMod))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is " + earlyStopMod + " ," +
" it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before set"));
throw new IllegalArgumentException();
}
this.earlyStopMod = earlyStopMod; this.earlyStopMod = earlyStopMod;
} }
@ -112,14 +182,17 @@ public class LocalFLParameter {
} }
public void setServerMod(String serverMod) { public void setServerMod(String serverMod) {
if (serverMod == null || serverMod.isEmpty()) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is null, please check it " +
"before set"));
throw new IllegalArgumentException();
}
if ((!ServerMod.HYBRID_TRAINING.toString().equals(serverMod)) &&
(!ServerMod.FEDERATED_LEARNING.toString().equals(serverMod))) {
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is " + serverMod + " , it " +
"must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before set"));
throw new IllegalArgumentException();
}
this.serverMod = serverMod; this.serverMod = serverMod;
} }
public String getSafeMod() {
return safeMod;
}
public void setSafeMod(String safeMod) {
this.safeMod = safeMod;
}
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,33 +13,73 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.security.InvalidKeyException;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.SignatureException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.logging.Logger;
import javax.net.ssl.HostnameVerifier; import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager; import javax.net.ssl.X509TrustManager;
import java.io.FileInputStream;
import java.io.InputStream;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.SignatureException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.logging.Logger;
/**
* Define SSL socket factory tools for https communication.
*
* @since 2021-06-30
*/
public class SSLSocketFactoryTools { public class SSLSocketFactoryTools {
private static final Logger LOGGER = Logger.getLogger(SSLSocketFactory.class.toString()); private static final Logger LOGGER = Logger.getLogger(SSLSocketFactory.class.toString());
private static volatile SSLSocketFactoryTools sslSocketFactoryTools;
private FLParameter flParameter = FLParameter.getInstance(); private FLParameter flParameter = FLParameter.getInstance();
private X509Certificate x509Certificate; private X509Certificate x509Certificate;
private SSLSocketFactory sslSocketFactory; private SSLSocketFactory sslSocketFactory;
private SSLContext sslContext; private SSLContext sslContext;
private MyTrustManager myTrustManager; private MyTrustManager myTrustManager;
private static volatile SSLSocketFactoryTools sslSocketFactoryTools; private final HostnameVerifier hostnameVerifier = new HostnameVerifier() {
@Override
public boolean verify(String hostname, SSLSession session) {
if (hostname == null || hostname.isEmpty()) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <hostname> is null or empty, " +
"please check!"));
throw new IllegalArgumentException();
}
if (session == null) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <session> is null, please " +
"check!"));
throw new IllegalArgumentException();
}
String domainName = flParameter.getDomainName();
if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the <domainName> is null or not valid, it should" +
" be like as https://...... , please check!"));
throw new IllegalArgumentException();
}
if (domainName.split("//")[1].split(":").length < 2) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the format of <domainName> is not valid, it " +
"should be like as https://127.0.0.1:6666 when setting <useSSL> to true, please check!"));
throw new IllegalArgumentException();
}
String ip = domainName.split("//")[1].split(":")[0];
return hostname.equals(ip);
}
};
private SSLSocketFactoryTools() { private SSLSocketFactoryTools() {
initSslSocketFactory(); initSslSocketFactory();
@ -52,14 +92,19 @@ public class SSLSocketFactoryTools {
myTrustManager = new MyTrustManager(x509Certificate); myTrustManager = new MyTrustManager(x509Certificate);
sslContext.init(null, new TrustManager[]{ sslContext.init(null, new TrustManager[]{
myTrustManager myTrustManager
}, new java.security.SecureRandom()); }, Common.getSecureRandom());
sslSocketFactory = sslContext.getSocketFactory(); sslSocketFactory = sslContext.getSocketFactory();
} catch (NoSuchAlgorithmException | KeyManagementException ex) {
} catch (Exception e) { LOGGER.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " +
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " + e.getMessage())); ex.getMessage()));
} }
} }
/**
* Get the singleton object of the class SSLSocketFactoryTools.
*
* @return the singleton object of the class SSLSocketFactoryTools.
*/
public static SSLSocketFactoryTools getInstance() { public static SSLSocketFactoryTools getInstance() {
SSLSocketFactoryTools localRef = sslSocketFactoryTools; SSLSocketFactoryTools localRef = sslSocketFactoryTools;
if (localRef == null) { if (localRef == null) {
@ -73,29 +118,37 @@ public class SSLSocketFactoryTools {
return localRef; return localRef;
} }
public X509Certificate readCert(String assetName) { private X509Certificate readCert(String assetName) {
InputStream inputStream = null; if (assetName == null || assetName.isEmpty()) {
try { LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <assetName> is null or empty, " +
inputStream = new FileInputStream(assetName); "please check!"));
} catch (Exception e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of read inputStream in readCert: " + e.getMessage()));
return null; return null;
} }
InputStream inputStream = null;
X509Certificate cert = null; X509Certificate cert = null;
try { try {
inputStream = new FileInputStream(assetName);
CertificateFactory cf = CertificateFactory.getInstance("X.509"); CertificateFactory cf = CertificateFactory.getInstance("X.509");
cert = (X509Certificate) cf.generateCertificate(inputStream); Certificate certificate = cf.generateCertificate(inputStream);
} catch (Exception e) { if (certificate instanceof X509Certificate) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of creating CertificateFactory in readCert: " + e.getMessage())); cert = (X509Certificate) certificate;
} else {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] cf.generateCertificate(inputStream) can not " +
"convert to X509Certificate"));
}
} catch (FileNotFoundException | CertificateException ex) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch FileNotFoundException or CertificateException " +
"when creating " +
"CertificateFactory in readCert: " + ex.getMessage()));
} finally { } finally {
try { try {
if (inputStream != null) { if (inputStream != null) {
inputStream.close(); inputStream.close();
} }
} catch (Throwable ex) { } catch (IOException ex) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch IOException: " + ex.getMessage()));
} }
} }
return cert; return cert;
} }
@ -111,7 +164,6 @@ public class SSLSocketFactoryTools {
return myTrustManager; return myTrustManager;
} }
private static final class MyTrustManager implements X509TrustManager { private static final class MyTrustManager implements X509TrustManager {
X509Certificate cert; X509Certificate cert;
@ -126,27 +178,25 @@ public class SSLSocketFactoryTools {
@Override @Override
public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
for (X509Certificate cert : chain) { for (X509Certificate cert : chain) {
// Make sure that it hasn't expired. // Make sure that it hasn't expired.
cert.checkValidity(); cert.checkValidity();
// Verify the certificate's public key chain. // Verify the certificate's public key chain.
try { try {
cert.verify(((X509Certificate) this.cert).getPublicKey()); cert.verify(this.cert.getPublicKey());
} catch (NoSuchAlgorithmException e) { } catch (NoSuchAlgorithmException e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage())); LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
throw new RuntimeException(); "NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage()));
} catch (InvalidKeyException e) { } catch (InvalidKeyException e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch InvalidKeyException in checkServerTrusted: " + e.getMessage())); LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
throw new RuntimeException(); "InvalidKeyException in checkServerTrusted: " + e.getMessage()));
} catch (NoSuchProviderException e) { } catch (NoSuchProviderException e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchProviderException in checkServerTrusted: " + e.getMessage())); LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
throw new RuntimeException(); "NoSuchProviderException in checkServerTrusted: " + e.getMessage()));
} catch (SignatureException e) { } catch (SignatureException e) {
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch SignatureException in checkServerTrusted: " + e.getMessage())); LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
throw new RuntimeException(); "SignatureException in checkServerTrusted: " + e.getMessage()));
} }
LOGGER.info(Common.addTag("checkServerTrusted success!")); LOGGER.info(Common.addTag("**********************checkServerTrusted success!**********************"));
} }
} }
@ -155,14 +205,4 @@ public class SSLSocketFactoryTools {
return new java.security.cert.X509Certificate[0]; return new java.security.cert.X509Certificate[0];
} }
} }
private final HostnameVerifier hostnameVerifier = new HostnameVerifier() {
@Override
public boolean verify(String hostname, SSLSession session) {
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] server hostname: " + flParameter.getHostName()));
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] client request hostname: " + hostname));
return hostname.equals(flParameter.getHostName());
}
};
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,127 +13,179 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet; import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap; import mindspore.schema.FeatureMap;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Random;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT; /**
import static com.mindspore.flclient.LocalFLParameter.LENET; * Defines encryption and decryption methods.
*
* @since 2021-06-30
*/
public class SecureProtocol { public class SecureProtocol {
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString()); private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
private static double deltaError = 1e-6d;
private static Map<String, float[]> modelMap;
private FLParameter flParameter = FLParameter.getInstance(); private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int iteration; private int iteration;
private CipherClient cipher; private CipherClient cipherClient;
private FLClientStatus status; private FLClientStatus status;
private float[] featureMask = new float[0]; private float[] featureMask = new float[0];
private double dpEps; private double dpEps;
private double dpDelta; private double dpDelta;
private double dpNormClip; private double dpNormClip;
private static double deltaError = 1e-6;
private static Map<String, float[]> modelMap;
private ArrayList<String> encryptFeatureName = new ArrayList<String>(); private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private int retCode; private int retCode;
/**
* Obtain current status code in client.
*
* @return current status code in client.
*/
public FLClientStatus getStatus() { public FLClientStatus getStatus() {
return status; return status;
} }
public float[] getFeatureMask() { /**
return featureMask; * Obtain retCode returned by server.
} *
* @return the retCode returned by server.
*/
public int getRetCode() { public int getRetCode() {
return retCode; return retCode;
} }
public SecureProtocol() { /**
} * Setting parameters for pairwise masking.
*
* @param iter current iteration of federated learning task.
* @param minSecretNum minimum number of secret fragments required to reconstruct a secret
* @param prime teh big prime number used to split secrets into pieces
* @param featureSize the total feature size in model
*/
public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) { public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) {
this.iteration = iter; if (prime == null || prime.length == 0) {
this.cipher = new CipherClient(iteration, minSecretNum, prime, featureSize); LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <prime> is null, please check!"));
} throw new IllegalArgumentException();
public FLClientStatus setDPParameter(int iter, double diffEps,
double diffDelta, double diffNorm, Map<String, float[]> map) {
try {
this.iteration = iter;
this.dpEps = diffEps;
this.dpDelta = diffDelta;
this.dpNormClip = diffNorm;
this.modelMap = map;
status = FLClientStatus.SUCCESS;
} catch (Exception e) {
LOGGER.severe(Common.addTag("[DPEncrypt] catch Exception in setDPParameter: " + e.getMessage()));
status = FLClientStatus.FAILED;
} }
return status; this.iteration = iter;
this.cipherClient = new CipherClient(iteration, minSecretNum, prime, featureSize);
} }
/**
* Setting parameters for differential privacy.
*
* @param iter current iteration of federated learning task.
* @param diffEps privacy budget eps of DP mechanism.
* @param diffDelta privacy budget delta of DP mechanism.
* @param diffNorm normClip factor of DP mechanism.
* @param map model weights.
* @return the status code corresponding to the response message.
*/
public FLClientStatus setDPParameter(int iter, double diffEps, double diffDelta, double diffNorm, Map<String,
float[]> map) {
this.iteration = iter;
this.dpEps = diffEps;
this.dpDelta = diffDelta;
this.dpNormClip = diffNorm;
this.modelMap = map;
return FLClientStatus.SUCCESS;
}
/**
* Obtain the feature names that needed to be encrypted.
*
* @return the feature names that needed to be encrypted.
*/
public ArrayList<String> getEncryptFeatureName() { public ArrayList<String> getEncryptFeatureName() {
return encryptFeatureName; return encryptFeatureName;
} }
/**
* Set the parameter encryptFeatureName.
*
* @param encryptFeatureName the feature names that needed to be encrypted.
*/
public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) { public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) {
this.encryptFeatureName = encryptFeatureName; this.encryptFeatureName = encryptFeatureName;
} }
/**
* Obtain the returned timestamp for next request from server.
*
* @return the timestamp for next request.
*/
public String getNextRequestTime() { public String getNextRequestTime() {
return cipher.getNextRequestTime(); return cipherClient.getNextRequestTime();
} }
/**
* Generate pairwise mask and individual mask.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus pwCreateMask() { public FLClientStatus pwCreateMask() {
LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "=============="); LOGGER.info(String.format("[PairWiseMask] ==============request flID: %s ==============",
localFLParameter.getFlID()));
// round 0 // round 0
status = cipher.exchangeKeys(); status = cipherClient.exchangeKeys();
retCode = cipher.getRetCode(); retCode = cipherClient.getRetCode();
LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============"); LOGGER.info(String.format("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: %s ",
"============", status));
if (status != FLClientStatus.SUCCESS) { if (status != FLClientStatus.SUCCESS) {
return status; return status;
} }
// round 1 // round 1
try { status = cipherClient.shareSecrets();
status = cipher.shareSecrets(); retCode = cipherClient.getRetCode();
retCode = cipher.getRetCode(); LOGGER.info(String.format("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: %s ",
LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "============="); "=============", status));
} catch (Exception e) {
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
status = FLClientStatus.FAILED;
}
if (status != FLClientStatus.SUCCESS) { if (status != FLClientStatus.SUCCESS) {
return status; return status;
} }
// round2 // round2
try { featureMask = cipherClient.doubleMaskingWeight();
featureMask = cipher.doubleMaskingWeight(); if (featureMask == null || featureMask.length <= 0) {
retCode = cipher.getRetCode(); LOGGER.severe(Common.addTag("[Encrypt] the returned featureMask from cipherClient.doubleMaskingWeight" +
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS============="); " is null, please check!"));
} catch (Exception e) { return FLClientStatus.FAILED;
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
status = FLClientStatus.FAILED;
} }
retCode = cipherClient.getRetCode();
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
return status; return status;
} }
/**
* Add the pairwise mask and individual mask to model weights.
*
* @param builder the FlatBufferBuilder object used for serialization model weights.
* @param trainDataSize trainDataSize tne size of train data set.
* @return the serialized model weights after adding masks.
*/
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) { public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
if (featureMask == null || featureMask.length == 0) { if (featureMask == null || featureMask.length == 0) {
LOGGER.severe("[Encrypt] feature mask is null, please check"); LOGGER.severe("[Encrypt] feature mask is null, please check");
return new int[0]; return new int[0];
} }
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length); LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length));
// get feature map // get feature map
Map<String, float[]> map = new HashMap<String, float[]>(); Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
@ -142,6 +194,9 @@ public class SecureProtocol {
} else if (flParameter.getFlName().equals(LENET)) { } else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
} else {
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
throw new IllegalArgumentException();
} }
int featureSize = encryptFeatureName.size(); int featureSize = encryptFeatureName.size();
int[] featuresMap = new int[featureSize]; int[] featuresMap = new int[featureSize];
@ -149,9 +204,13 @@ public class SecureProtocol {
for (int i = 0; i < featureSize; i++) { for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i); String key = encryptFeatureName.get(i);
float[] data = map.get(key); float[] data = map.get(key);
LOGGER.info("[Encrypt] feature name: " + key + " feature size: " + data.length); LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length));
for (int j = 0; j < data.length; j++) { for (int j = 0; j < data.length; j++) {
float rawData = data[j]; float rawData = data[j];
if (maskIndex >= featureMask.length) {
LOGGER.severe("[Encrypt] the maskIndex is out of range for array featureMask, please check");
return new int[0];
}
float maskData = rawData * trainDataSize + featureMask[maskIndex]; float maskData = rawData * trainDataSize + featureMask[maskIndex];
maskIndex += 1; maskIndex += 1;
data[j] = maskData; data[j] = maskData;
@ -164,17 +223,23 @@ public class SecureProtocol {
return featuresMap; return featuresMap;
} }
/**
* Reconstruct the secrets used for unmasking model weights.
*
* @return current status code in client.
*/
public FLClientStatus pwUnmasking() { public FLClientStatus pwUnmasking() {
status = cipher.reconstructSecrets(); // round3 status = cipherClient.reconstructSecrets(); // round3
retCode = cipher.getRetCode(); retCode = cipherClient.getRetCode();
LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "============="); LOGGER.info(String.format("[Encrypt] =============GetClientList+SendReconstructSecret: %s =============",
status));
return status; return status;
} }
private static float calculateErf(double x) { private static float calculateErf(double erfInput) {
double result = 0; double result = 0d;
int segmentNum = 10000; int segmentNum = 10000;
double deltaX = x / segmentNum; double deltaX = erfInput / segmentNum;
result += 1; result += 1;
for (int i = 1; i < segmentNum; i++) { for (int i = 1; i < segmentNum; i++) {
result += 2 * Math.exp(-Math.pow(deltaX * i, 2)); result += 2 * Math.exp(-Math.pow(deltaX * i, 2));
@ -183,33 +248,36 @@ public class SecureProtocol {
return (float) (result * deltaX / Math.pow(Math.PI, 0.5)); return (float) (result * deltaX / Math.pow(Math.PI, 0.5));
} }
private static double calculatePhi(double t) { private static double calculatePhi(double phiInput) {
return 0.5 * (1.0 + calculateErf((t / Math.sqrt(2.0)))); return 0.5 * (1.0 + calculateErf((phiInput / Math.sqrt(2.0))));
} }
private static double calculateBPositive(double eps, double s) { private static double calculateBPositive(double eps, double calInput) {
return calculatePhi(Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0))); return calculatePhi(Math.sqrt(eps * calInput)) -
Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
} }
private static double calculateBNegative(double eps, double s) { private static double calculateBNegative(double eps, double calInput) {
return calculatePhi(-Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0))); return calculatePhi(-Math.sqrt(eps * calInput)) -
Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
} }
private static double calculateSPositive(double eps, double targetDelta, double sInf, double sSup) { private static double calculateSPositive(double eps, double targetDelta, double initSInf, double initSSup) {
double deltaSup = calculateBPositive(eps, sSup); double deltaSup = calculateBPositive(eps, initSSup);
double sInf = initSInf;
double sSup = initSSup;
while (deltaSup <= targetDelta) { while (deltaSup <= targetDelta) {
sInf = sSup; sInf = sSup;
sSup = 2 * sInf; sSup = 2 * sInf;
deltaSup = calculateBPositive(eps, sSup); deltaSup = calculateBPositive(eps, sSup);
} }
double sMid = sInf + (sSup - sInf) / 2.0; double sMid = sInf + (sSup - sInf) / 2.0;
int iterMax = 1000; int iterMax = 1000;
int iters = 0; int iters = 0;
while (true) { while (true) {
double b = calculateBPositive(eps, sMid); double bPositive = calculateBPositive(eps, sMid);
if (b <= targetDelta) { if (bPositive <= targetDelta) {
if (targetDelta - b <= deltaError) { if (targetDelta - bPositive <= deltaError) {
break; break;
} else { } else {
sInf = sMid; sInf = sMid;
@ -226,8 +294,10 @@ public class SecureProtocol {
return sMid; return sMid;
} }
private static double calculateSNegative(double eps, double targetDelta, double sInf, double sSup) { private static double calculateSNegative(double eps, double targetDelta, double initSInf, double initSSup) {
double deltaSup = calculateBNegative(eps, sSup); double deltaSup = calculateBNegative(eps, initSSup);
double sInf = initSInf;
double sSup = initSSup;
while (deltaSup > targetDelta) { while (deltaSup > targetDelta) {
sInf = sSup; sInf = sSup;
sSup = 2 * sInf; sSup = 2 * sInf;
@ -238,9 +308,9 @@ public class SecureProtocol {
int iterMax = 1000; int iterMax = 1000;
int iters = 0; int iters = 0;
while (true) { while (true) {
double b = calculateBNegative(eps, sMid); double bNegative = calculateBNegative(eps, sMid);
if (b <= targetDelta) { if (bNegative <= targetDelta) {
if (targetDelta - b <= deltaError) { if (targetDelta - bNegative <= deltaError) {
break; break;
} else { } else {
sSup = sMid; sSup = sMid;
@ -259,17 +329,26 @@ public class SecureProtocol {
private static double calculateSigma(double clipNorm, double eps, double targetDelta) { private static double calculateSigma(double clipNorm, double eps, double targetDelta) {
double deltaZero = calculateBPositive(eps, 0); double deltaZero = calculateBPositive(eps, 0);
double alpha = 1; double alpha = 1d;
if (targetDelta > deltaZero) { if (targetDelta > deltaZero) {
double s = calculateSPositive(eps, targetDelta, 0, 1); double sPositive = calculateSPositive(eps, targetDelta, 0, 1);
alpha = Math.sqrt(1.0 + s / 2.0) - Math.sqrt(s / 2.0); alpha = Math.sqrt(1.0 + sPositive / 2.0) - Math.sqrt(sPositive / 2.0);
} else if (targetDelta < deltaZero) { } else if (targetDelta < deltaZero) {
double s = calculateSNegative(eps, targetDelta, 0, 1); double sNegative = calculateSNegative(eps, targetDelta, 0, 1);
alpha = Math.sqrt(1.0 + s / 2.0) + Math.sqrt(s / 2.0); alpha = Math.sqrt(1.0 + sNegative / 2.0) + Math.sqrt(sNegative / 2.0);
} else {
LOGGER.info(Common.addTag("[Encrypt] targetDelta = deltaZero"));
} }
return alpha * clipNorm / Math.sqrt(2.0 * eps); return alpha * clipNorm / Math.sqrt(2.0 * eps);
} }
/**
* Add differential privacy mask to model weights.
*
* @param builder the FlatBufferBuilder object used for serialization model weights.
* @param trainDataSize tne size of train data set.
* @return the serialized model weights after adding masks.
*/
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) { public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
// get feature map // get feature map
Map<String, float[]> map = new HashMap<String, float[]>(); Map<String, float[]> map = new HashMap<String, float[]>();
@ -279,6 +358,9 @@ public class SecureProtocol {
} else if (flParameter.getFlName().equals(LENET)) { } else if (flParameter.getFlName().equals(LENET)) {
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession())); map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
} else {
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
throw new IllegalArgumentException();
} }
Map<String, float[]> mapBeforeTrain = modelMap; Map<String, float[]> mapBeforeTrain = modelMap;
int featureSize = encryptFeatureName.size(); int featureSize = encryptFeatureName.size();
@ -286,19 +368,18 @@ public class SecureProtocol {
double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta); double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta);
LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "=============")); LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "============="));
// prepare gaussian noise
SecureRandom random = new SecureRandom();
int randomInt = random.nextInt();
Random r = new Random(randomInt);
// calculate l2-norm of all layers' update array // calculate l2-norm of all layers' update array
double updateL2Norm = 0; double updateL2Norm = 0d;
for (int i = 0; i < featureSize; i++) { for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i); String key = encryptFeatureName.get(i);
float[] data = map.get(key); float[] data = map.get(key);
float[] dataBeforeTrain = mapBeforeTrain.get(key); float[] dataBeforeTrain = mapBeforeTrain.get(key);
for (int j = 0; j < data.length; j++) { for (int j = 0; j < data.length; j++) {
float rawData = data[j]; float rawData = data[j];
if (j >= dataBeforeTrain.length) {
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
return new int[0];
}
float rawDataBeforeTrain = dataBeforeTrain[j]; float rawDataBeforeTrain = dataBeforeTrain[j];
float updateData = rawData - rawDataBeforeTrain; float updateData = rawData - rawDataBeforeTrain;
updateL2Norm += updateData * updateData; updateL2Norm += updateData * updateData;
@ -311,11 +392,26 @@ public class SecureProtocol {
int[] featuresMap = new int[featureSize]; int[] featuresMap = new int[featureSize];
for (int i = 0; i < featureSize; i++) { for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i); String key = encryptFeatureName.get(i);
if (!map.containsKey(key)) {
LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!");
return new int[0];
}
float[] data = map.get(key); float[] data = map.get(key);
float[] data2 = new float[data.length]; float[] data2 = new float[data.length];
if (!mapBeforeTrain.containsKey(key)) {
LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!");
return new int[0];
}
float[] dataBeforeTrain = mapBeforeTrain.get(key); float[] dataBeforeTrain = mapBeforeTrain.get(key);
// prepare gaussian noise
SecureRandom secureRandom = Common.getSecureRandom();
for (int j = 0; j < data.length; j++) { for (int j = 0; j < data.length; j++) {
float rawData = data[j]; float rawData = data[j];
if (j >= dataBeforeTrain.length) {
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
return new int[0];
}
float rawDataBeforeTrain = dataBeforeTrain[j]; float rawDataBeforeTrain = dataBeforeTrain[j];
float updateData = rawData - rawDataBeforeTrain; float updateData = rawData - rawDataBeforeTrain;
@ -323,7 +419,7 @@ public class SecureProtocol {
updateData *= clipFactor; updateData *= clipFactor;
// add noise // add noise
double gaussianNoise = r.nextGaussian() * gaussianSigma; double gaussianNoise = secureRandom.nextGaussian() * gaussianSigma;
updateData += gaussianNoise; updateData += gaussianNoise;
data2[j] = rawDataBeforeTrain + updateData; data2[j] = rawDataBeforeTrain + updateData;
data2[j] = data2[j] * trainDataSize; data2[j] = data2[j] * trainDataSize;
@ -335,5 +431,4 @@ public class SecureProtocol {
} }
return featuresMap; return featuresMap;
} }
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,8 +13,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
/**
* The training mode of federated learning.
*
* @since 2021-06-30
*/
public enum ServerMod { public enum ServerMod {
FEDERATED_LEARNING, FEDERATED_LEARNING,
HYBRID_TRAINING HYBRID_TRAINING

View File

@ -1,6 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
@ -13,13 +12,20 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlInferBert; import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet; import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FLPlan;
import mindspore.schema.FeatureMap; import mindspore.schema.FeatureMap;
import mindspore.schema.RequestFLJob; import mindspore.schema.RequestFLJob;
import mindspore.schema.ResponseCode; import mindspore.schema.ResponseCode;
@ -28,15 +34,28 @@ import mindspore.schema.ResponseFLJob;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT; /**
import static com.mindspore.flclient.LocalFLParameter.LENET; * StartFLJob
*
* @since 2021-08-25
*/
public class StartFLJob { public class StartFLJob {
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString());
private static volatile StartFLJob startFLJob;
static { static {
System.loadLibrary("mindspore-lite-jni"); System.loadLibrary("mindspore-lite-jni");
} }
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString()); private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int featureSize;
private String nextRequestTime;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private StartFLJob() {
}
class RequestStartFLJobBuilder { class RequestStartFLJobBuilder {
private RequestFLJob requestFLJob; private RequestFLJob requestFLJob;
@ -51,21 +70,53 @@ public class StartFLJob {
builder = new FlatBufferBuilder(); builder = new FlatBufferBuilder();
} }
/**
* set flName
*
* @param name String
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder flName(String name) { public RequestStartFLJobBuilder flName(String name) {
if (name == null || name.isEmpty()) {
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <name> is null or empty, please check!"));
throw new IllegalArgumentException();
}
this.nameOffset = this.builder.createString(name); this.nameOffset = this.builder.createString(name);
return this; return this;
} }
/**
* set id
*
* @param id String
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder id(String id) { public RequestStartFLJobBuilder id(String id) {
if (id == null || id.isEmpty()) {
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <id> is null or empty, please check!"));
throw new IllegalArgumentException();
}
this.idOffset = this.builder.createString(id); this.idOffset = this.builder.createString(id);
return this; return this;
} }
/**
* set time
*
* @param timestamp long
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder time(long timestamp) { public RequestStartFLJobBuilder time(long timestamp) {
this.timestampOffset = builder.createString(String.valueOf(timestamp)); this.timestampOffset = builder.createString(String.valueOf(timestamp));
return this; return this;
} }
/**
* set dataSize
*
* @param dataSize int
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder dataSize(int dataSize) { public RequestStartFLJobBuilder dataSize(int dataSize) {
// temp code need confirm // temp code need confirm
this.dataSize = dataSize; this.dataSize = dataSize;
@ -73,11 +124,22 @@ public class StartFLJob {
return this; return this;
} }
/**
* set iteration
*
* @param iteration iteration
* @return RequestStartFLJobBuilder
*/
public RequestStartFLJobBuilder iteration(int iteration) { public RequestStartFLJobBuilder iteration(int iteration) {
this.iteration = iteration; this.iteration = iteration;
return this; return this;
} }
/**
* build protobuffer
*
* @return byte[] data
*/
public byte[] build() { public byte[] build() {
int root = RequestFLJob.createRequestFLJob(this.builder, this.nameOffset, this.idOffset, this.iteration, int root = RequestFLJob.createRequestFLJob(this.builder, this.nameOffset, this.idOffset, this.iteration,
this.dataSize, this.timestampOffset); this.dataSize, this.timestampOffset);
@ -86,20 +148,11 @@ public class StartFLJob {
} }
} }
private static volatile StartFLJob startFLJob; /**
* getInstance of StartFLJob
private FLClientStatus status; *
* @return StartFLJob instance
private FLParameter flParameter = FLParameter.getInstance(); */
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int featureSize;
private String nextRequestTime;
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
private StartFLJob() {
}
public static StartFLJob getInstance() { public static StartFLJob getInstance() {
StartFLJob localRef = startFLJob; StartFLJob localRef = startFLJob;
if (localRef == null) { if (localRef == null) {
@ -117,6 +170,14 @@ public class StartFLJob {
return nextRequestTime; return nextRequestTime;
} }
/**
* get request start FLJob
*
* @param dataSize dataSize
* @param iteration iteration
* @param time time
* @return byte[] data
*/
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) { public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) {
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder(); RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
return builder.flName(flParameter.getFlName()) return builder.flName(flParameter.getFlName())
@ -135,6 +196,7 @@ public class StartFLJob {
return encryptFeatureName; return encryptFeatureName;
} }
private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) { private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) {
int fmCount = flJob.featureMapLength(); int fmCount = flJob.featureMapLength();
encryptFeatureName.clear(); encryptFeatureName.clear();
@ -149,6 +211,10 @@ public class StartFLJob {
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>(); ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) { for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i); FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname(); String featureName = feature.weightFullname();
if (localFLParameter.getAlbertWeightName().contains(featureName)) { if (localFLParameter.getAlbertWeightName().contains(featureName)) {
albertFeatureMaps.add(feature); albertFeatureMaps.add(feature);
@ -160,19 +226,23 @@ public class StartFLJob {
} else { } else {
continue; continue;
} }
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
"weightLength: " + feature.dataLength()));
} }
int tag = 0; int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------")); LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
"model-----------------"));
AlInferBert alInferBert = AlInferBert.getInstance(); AlInferBert alInferBert = AlInferBert.getInstance();
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps); tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
inferFeatureMaps);
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>")); LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------")); LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps); tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
albertFeatureMaps);
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>")); LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
@ -182,16 +252,22 @@ public class StartFLJob {
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>(); ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
for (int i = 0; i < fmCount; i++) { for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i); FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname(); String featureName = feature.weightFullname();
featureMaps.add(feature); featureMaps.add(feature);
featureSize += feature.dataLength(); featureSize += feature.dataLength();
encryptFeatureName.add(featureName); encryptFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
"weightLength: " + feature.dataLength()));
} }
int tag = 0; int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------")); LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps); tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
featureMaps);
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>")); LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
@ -206,11 +282,16 @@ public class StartFLJob {
encryptFeatureName.clear(); encryptFeatureName.clear();
for (int i = 0; i < fmCount; i++) { for (int i = 0; i < fmCount; i++) {
FeatureMap feature = flJob.featureMap(i); FeatureMap feature = flJob.featureMap(i);
if (feature == null) {
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
return FLClientStatus.FAILED;
}
String featureName = feature.weightFullname(); String featureName = feature.weightFullname();
featureMaps.add(feature); featureMaps.add(feature);
featureSize += feature.dataLength(); featureSize += feature.dataLength();
encryptFeatureName.add(featureName); encryptFeatureName.add(featureName);
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength())); LOGGER.info(Common.addTag("[startFLJob] weightFullname: " +
feature.weightFullname() + ", weightLength: " + feature.dataLength()));
} }
int tag = 0; int tag = 0;
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------")); LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
@ -223,7 +304,22 @@ public class StartFLJob {
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
} }
/**
* response res
*
* @param flJob ResponseFLJob
* @return FLClientStatus
*/
public FLClientStatus doResponse(ResponseFLJob flJob) { public FLClientStatus doResponse(ResponseFLJob flJob) {
if (flJob == null) {
LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null"));
return FLClientStatus.FAILED;
}
FLPlan flPlanConfig = flJob.flPlanConfig();
if (flPlanConfig == null) {
LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null"));
return FLClientStatus.FAILED;
}
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode())); LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode()));
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason())); LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration())); LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
@ -236,11 +332,12 @@ public class StartFLJob {
switch (retCode) { switch (retCode) {
case (ResponseCode.SUCCEED): case (ResponseCode.SUCCEED):
localFLParameter.setServerMod(flJob.flPlanConfig().serverMode()); localFLParameter.setServerMod(flPlanConfig.serverMode());
if (ALBERT.equals(flParameter.getFlName())) { if (ALBERT.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>")); LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
status = parseResponseAlbert(flJob); status = parseResponseAlbert(flJob);
} else if (LENET.equals(flParameter.getFlName())) { }
if (LENET.equals(flParameter.getFlName())) {
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>")); LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
status = parseResponseLenet(flJob); status = parseResponseLenet(flJob);
} }
@ -256,8 +353,4 @@ public class StartFLJob {
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
public FLClientStatus getStatus() {
return this.status;
}
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -13,35 +13,49 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient; package com.mindspore.flclient;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.ResponseGetModel;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME; import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import static com.mindspore.flclient.LocalFLParameter.ALBERT; import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET; import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.mindspore.flclient.model.AlInferBert;
import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.ResponseGetModel;
import java.nio.ByteBuffer;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
/**
* SyncFLJob defines the APIs for federated learning task.
* API flJobRun() for starting federated learning on the device, the API modelInference() for inference on the
* device, and the API getModel() for obtaining the latest model on the cloud.
*
* @since 2021-06-30
*/
public class SyncFLJob { public class SyncFLJob {
private static final Logger LOGGER = Logger.getLogger(SyncFLJob.class.toString()); private static final Logger LOGGER = Logger.getLogger(SyncFLJob.class.toString());
private FLParameter flParameter = FLParameter.getInstance(); private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback(); private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
private Map<String, float[]> oldFeatureMap; private Map<String, float[]> oldFeatureMap;
public SyncFLJob() { /**
} * Starts a federated learning task on the device.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus flJobRun() { public FLClientStatus flJobRun() {
Common.setSecureRandom(new SecureRandom());
localFLParameter.setFlID(flParameter.getClientID()); localFLParameter.setFlID(flParameter.getClientID());
FLLiteClient client = new FLLiteClient(); FLLiteClient client = new FLLiteClient();
FLClientStatus curStatus; FLClientStatus curStatus;
@ -58,17 +72,14 @@ public class SyncFLJob {
if (trainDataSize <= 0) { if (trainDataSize <= 0) {
LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0")); LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0"));
curStatus = FLClientStatus.FAILED; curStatus = FLClientStatus.FAILED;
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode()); flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
client.getRetCode());
break; break;
} }
client.setTrainDataSize(trainDataSize); client.setTrainDataSize(trainDataSize);
// startFLJob // startFLJob
curStatus = client.startFLJob(); curStatus = startFLJob(client);
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.startFLJob();
}
if (curStatus == FLClientStatus.RESTART) { if (curStatus == FLClientStatus.RESTART) {
restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue; continue;
@ -100,11 +111,7 @@ public class SyncFLJob {
LOGGER.info(Common.addTag("[train] train succeed")); LOGGER.info(Common.addTag("[train] train succeed"));
// updateModel // updateModel
curStatus = client.updateModel(); curStatus = updateModel(client);
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.updateModel();
}
if (curStatus == FLClientStatus.RESTART) { if (curStatus == FLClientStatus.RESTART) {
restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue; continue;
@ -124,11 +131,7 @@ public class SyncFLJob {
} }
// getModel // getModel
curStatus = client.getModel(); curStatus = getModel(client);
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.getModel();
}
if (curStatus == FLClientStatus.RESTART) { if (curStatus == FLClientStatus.RESTART) {
restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode()); restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
continue; continue;
@ -142,7 +145,8 @@ public class SyncFLJob {
// evaluate model after getting model from server // evaluate model after getting model from server
if (flParameter.getTestDataset().equals("null")) { if (flParameter.getTestDataset().equals("null")) {
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the combine model")); LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the model after getting" +
" model from server"));
} else { } else {
curStatus = client.evaluateModel(); curStatus = client.evaluateModel();
if (curStatus == FLClientStatus.FAILED) { if (curStatus == FLClientStatus.FAILED) {
@ -151,33 +155,68 @@ public class SyncFLJob {
} }
LOGGER.info(Common.addTag("[evaluate] evaluate succeed")); LOGGER.info(Common.addTag("[evaluate] evaluate succeed"));
} }
LOGGER.info(Common.addTag("========================================================the total response of " + client.getIteration() + ": " + curStatus + "======================================================================")); LOGGER.info(Common.addTag("========================================================the total response of "
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode()); + client.getIteration() + ": " + curStatus +
"======================================================================"));
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
client.getRetCode());
} while (client.getIteration() < client.getIterations()); } while (client.getIteration() < client.getIterations());
client.finalize(); client.freeSession();
LOGGER.info(Common.addTag("flJobRun finish")); LOGGER.info(Common.addTag("flJobRun finish"));
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode()); flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
return curStatus; return curStatus;
} }
private FLClientStatus startFLJob(FLLiteClient client) {
FLClientStatus curStatus = client.startFLJob();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.startFLJob();
}
return curStatus;
}
private FLClientStatus updateModel(FLLiteClient client) {
FLClientStatus curStatus = client.updateModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.updateModel();
}
return curStatus;
}
private FLClientStatus getModel(FLLiteClient client) {
FLClientStatus curStatus = client.getModel();
while (curStatus == FLClientStatus.WAIT) {
waitSomeTime();
curStatus = client.getModel();
}
return curStatus;
}
private void updateDpNormClip(FLLiteClient client) { private void updateDpNormClip(FLLiteClient client) {
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel(); EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
if (encryptLevel == EncryptLevel.DP_ENCRYPT) { if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
int currentIter = client.getIteration(); int currentIter = client.getIteration();
Map<String, float[]> fedFeatureMap = getFeatureMap(); Map<String, float[]> fedFeatureMap = getFeatureMap();
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap); float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
if (fedWeightUpdateNorm == -1) {
LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " +
"-1, please check!"));
throw new IllegalArgumentException();
}
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm)); LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
float newNormCLip = (float) client.dpNormClipFactor * fedWeightUpdateNorm; float newNormCLip = (float) client.getDpNormClipFactor() * fedWeightUpdateNorm;
if (currentIter == 1) { if (currentIter == 1) {
client.dpNormClipAdapt = newNormCLip; client.setDpNormClipAdapt(newNormCLip);
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated.")); LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
} else { } else {
if (newNormCLip < client.dpNormClipAdapt) { if (newNormCLip < client.getDpNormClipAdapt()) {
client.dpNormClipAdapt = newNormCLip; client.setDpNormClipAdapt(newNormCLip);
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated.")); LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
} }
} }
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.dpNormClipAdapt)); LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.getDpNormClipAdapt()));
} }
} }
@ -190,11 +229,16 @@ public class SyncFLJob {
} }
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) { private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
float updateL2Norm = 0; float updateL2Norm = 0f;
for (String key : originalData.keySet()) { for (String key : originalData.keySet()) {
float[] data = originalData.get(key); float[] data = originalData.get(key);
float[] dataAfterUpdate = newData.get(key); float[] dataAfterUpdate = newData.get(key);
for (int j = 0; j < data.length; j++) { for (int j = 0; j < data.length; j++) {
if (j >= dataAfterUpdate.length) {
LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " +
"please check");
return -1;
}
float updateData = data[j] - dataAfterUpdate[j]; float updateData = data[j] - dataAfterUpdate[j];
updateL2Norm += updateData * updateData; updateL2Norm += updateData * updateData;
} }
@ -215,12 +259,22 @@ public class SyncFLJob {
return featureMap; return featureMap;
} }
/**
* Starts an inference task on the device.
*
* @return the status code corresponding to the response message.
*/
public int[] modelInference() { public int[] modelInference() {
int[] labels = new int[0]; int[] labels = new int[0];
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
AlInferBert alInferBert = AlInferBert.getInstance(); AlInferBert alInferBert = AlInferBert.getInstance();
LOGGER.info(Common.addTag("===========model inference=============")); LOGGER.info(Common.addTag("===========model inference============="));
labels = alInferBert.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile()); labels = alInferBert.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset(),
flParameter.getVocabFile(), flParameter.getIdsFile());
if (labels == null || labels.length == 0) {
LOGGER.severe("[model inference] the returned label from adInferBert.inferModel() is null, please " +
"check");
}
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels))); LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
SessionUtil.free(alInferBert.getTrainSession()); SessionUtil.free(alInferBert.getTrainSession());
LOGGER.info(Common.addTag("[model inference] inference finish")); LOGGER.info(Common.addTag("[model inference] inference finish"));
@ -228,49 +282,62 @@ public class SyncFLJob {
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
LOGGER.info(Common.addTag("===========model inference=============")); LOGGER.info(Common.addTag("===========model inference============="));
labels = trainLenet.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset().split(",")[0]); labels = trainLenet.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset().split(",")[0]);
if (labels == null || labels.length == 0) {
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
}
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels))); LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
SessionUtil.free(trainLenet.getTrainSession()); SessionUtil.free(trainLenet.getTrainSession());
LOGGER.info(Common.addTag("[model inference] inference finish")); LOGGER.info(Common.addTag("[model inference] inference finish"));
} }
if (labels.length == 0) {
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
}
return labels; return labels;
} }
/**
* Obtains the latest model on the cloud.
*
* @return the status code corresponding to the response message.
*/
public FLClientStatus getModel() { public FLClientStatus getModel() {
Common.setSecureRandom(Common.getFastSecureRandom());
int tag = 0; int tag = 0;
FLClientStatus status = FLClientStatus.SUCCESS; FLClientStatus status;
try { try {
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString()); localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
flParameter.getTrainModelPath() + " Create Train Session============="));
AlTrainBert alTrainBert = AlTrainBert.getInstance(); AlTrainBert alTrainBert = AlTrainBert.getInstance();
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true); tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1")); LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the " +
"return is -1"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session=============")); LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " +
flParameter.getInferModelPath() + " Create inference Session============="));
AlInferBert alInferBert = AlInferBert.getInstance(); AlInferBert alInferBert = AlInferBert.getInstance();
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false); tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
} else if (flParameter.getFlName().equals(LENET)) { } else if (flParameter.getFlName().equals(LENET)) {
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString()); localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session=============")); LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
flParameter.getTrainModelPath() + " Create Train Session============="));
TrainLenet trainLenet = TrainLenet.getInstance(); TrainLenet trainLenet = TrainLenet.getInstance();
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true); tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
} }
if (tag == -1) { if (tag == -1) {
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1")); LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
"is -1"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
FLCommunication flCommunication = FLCommunication.getInstance(); FLCommunication flCommunication = FLCommunication.getInstance();
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
flParameter.getDomainName());
GetModel getModelBuf = GetModel.getInstance(); GetModel getModelBuf = GetModel.getInstance();
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0); byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer); byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(message)) {
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
"again"));
status = FLClientStatus.WAIT; status = FLClientStatus.WAIT;
return status; return status;
} }
@ -279,8 +346,8 @@ public class SyncFLJob {
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer); ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
status = getModelBuf.doResponse(responseDataBuf); status = getModelBuf.doResponse(responseDataBuf);
LOGGER.info(Common.addTag("[getModel] success!")); LOGGER.info(Common.addTag("[getModel] success!"));
} catch (Exception e) { } catch (Exception ex) {
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage())); LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + ex.getMessage()));
status = FLClientStatus.FAILED; status = FLClientStatus.FAILED;
} }
if (flParameter.getFlName().equals(ALBERT)) { if (flParameter.getFlName().equals(ALBERT)) {
@ -299,19 +366,16 @@ public class SyncFLJob {
} }
private void waitSomeTime() { private void waitSomeTime() {
if (flParameter.getSleepTime() != 0) if (flParameter.getSleepTime() != 0) {
Common.sleep(flParameter.getSleepTime()); Common.sleep(flParameter.getSleepTime());
else } else {
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
}
} }
private void waitNextReqTime(String nextReqTime) { private void waitNextReqTime(String nextReqTime) {
if (flParameter.isTimer()) { long waitTime = Common.getWaitTime(nextReqTime);
long waitTime = Common.getWaitTime(nextReqTime); Common.sleep(waitTime);
Common.sleep(waitTime);
} else {
waitSomeTime();
}
} }
private void restart(String tag, String nextReqTime, int iteration, int retcode) { private void restart(String tag, String nextReqTime, int iteration, int retcode) {
@ -322,7 +386,8 @@ public class SyncFLJob {
private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) { private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) {
LOGGER.info(Common.addTag(tag + " failed")); LOGGER.info(Common.addTag(tag + " failed"));
LOGGER.info(Common.addTag("========================================================the total response of " + iteration + ": " + curStatus + "======================================================================")); LOGGER.info(Common.addTag("=========================================the total response of " +
iteration + ": " + curStatus + "========================================="));
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode); flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
} }
@ -334,53 +399,27 @@ public class SyncFLJob {
String flName = args[4]; String flName = args[4];
String trainModelPath = args[5]; String trainModelPath = args[5];
String inferModelPath = args[6]; String inferModelPath = args[6];
String clientID = args[7]; boolean useSSL = Boolean.parseBoolean(args[7]);
String ip = args[8]; String domainName = args[8];
boolean useSSL = Boolean.parseBoolean(args[9]); boolean useElb = Boolean.parseBoolean(args[9]);
int port = Integer.parseInt(args[10]); int serverNum = Integer.parseInt(args[10]);
int timeWindow = Integer.parseInt(args[11]); String certPath = args[11];
boolean useElb = Boolean.parseBoolean(args[12]); String task = args[12];
int serverNum = Integer.parseInt(args[13]);
boolean useHttps = Boolean.parseBoolean(args[14]);
String certPath = args[15];
String task = args[16];
FLParameter flParameter = FLParameter.getInstance(); FLParameter flParameter = FLParameter.getInstance();
LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset));
LOGGER.info(Common.addTag("[args] vocabFile: " + vocabFile));
LOGGER.info(Common.addTag("[args] idsFile: " + idsFile));
LOGGER.info(Common.addTag("[args] testDataset: " + testDataset));
LOGGER.info(Common.addTag("[args] flName: " + flName));
LOGGER.info(Common.addTag("[args] trainModelPath: " + trainModelPath));
LOGGER.info(Common.addTag("[args] inferModelPath: " + inferModelPath));
LOGGER.info(Common.addTag("[args] clientID: " + clientID));
LOGGER.info(Common.addTag("[args] ip: " + ip));
LOGGER.info(Common.addTag("[args] useSSL: " + useSSL));
LOGGER.info(Common.addTag("[args] port: " + port));
LOGGER.info(Common.addTag("[args] timeWindow: " + timeWindow));
LOGGER.info(Common.addTag("[args] useElb: " + useElb));
LOGGER.info(Common.addTag("[args] serverNum: " + serverNum));
LOGGER.info(Common.addTag("[args] useHttps: " + useHttps));
LOGGER.info(Common.addTag("[args] certPath: " + certPath));
LOGGER.info(Common.addTag("[args] task: " + task));
flParameter.setClientID(clientID);
SyncFLJob syncFLJob = new SyncFLJob(); SyncFLJob syncFLJob = new SyncFLJob();
if (task.equals("train")) { if (task.equals("train")) {
flParameter.setUseHttps(useHttps);
if (useSSL) { if (useSSL) {
flParameter.setCertPath(certPath); flParameter.setCertPath(certPath);
} }
flParameter.setHostName(ip);
flParameter.setTrainDataset(trainDataset); flParameter.setTrainDataset(trainDataset);
flParameter.setFlName(flName); flParameter.setFlName(flName);
flParameter.setTrainModelPath(trainModelPath); flParameter.setTrainModelPath(trainModelPath);
flParameter.setTestDataset(testDataset); flParameter.setTestDataset(testDataset);
flParameter.setInferModelPath(inferModelPath); flParameter.setInferModelPath(inferModelPath);
flParameter.setIp(ip);
flParameter.setUseSSL(useSSL); flParameter.setUseSSL(useSSL);
flParameter.setPort(port); flParameter.setDomainName(domainName);
flParameter.setTimeWindow(timeWindow);
flParameter.setUseElb(useElb); flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum); flParameter.setServerNum(serverNum);
if (ALBERT.equals(flName)) { if (ALBERT.equals(flName)) {
@ -398,17 +437,14 @@ public class SyncFLJob {
} }
syncFLJob.modelInference(); syncFLJob.modelInference();
} else if (task.equals("getModel")) { } else if (task.equals("getModel")) {
flParameter.setUseHttps(useHttps);
if (useSSL) { if (useSSL) {
flParameter.setCertPath(certPath); flParameter.setCertPath(certPath);
} }
flParameter.setHostName(ip);
flParameter.setFlName(flName); flParameter.setFlName(flName);
flParameter.setTrainModelPath(trainModelPath); flParameter.setTrainModelPath(trainModelPath);
flParameter.setInferModelPath(inferModelPath); flParameter.setInferModelPath(inferModelPath);
flParameter.setIp(ip);
flParameter.setUseSSL(useSSL); flParameter.setUseSSL(useSSL);
flParameter.setPort(port); flParameter.setDomainName(domainName);
flParameter.setUseElb(useElb); flParameter.setUseElb(useElb);
flParameter.setServerNum(serverNum); flParameter.setServerNum(serverNum);
syncFLJob.getModel(); syncFLJob.getModel();

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,10 +16,15 @@
package com.mindspore.flclient; package com.mindspore.flclient;
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
import static com.mindspore.flclient.LocalFLParameter.LENET;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.model.AlTrainBert; import com.mindspore.flclient.model.AlTrainBert;
import com.mindspore.flclient.model.SessionUtil; import com.mindspore.flclient.model.SessionUtil;
import com.mindspore.flclient.model.TrainLenet; import com.mindspore.flclient.model.TrainLenet;
import mindspore.schema.FeatureMap; import mindspore.schema.FeatureMap;
import mindspore.schema.RequestUpdateModel; import mindspore.schema.RequestUpdateModel;
import mindspore.schema.ResponseCode; import mindspore.schema.ResponseCode;
@ -31,145 +36,31 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.LocalFLParameter.ALBERT; /**
import static com.mindspore.flclient.LocalFLParameter.LENET; * Define the serialization method, handle the response message returned from server for updateModel request.
*
* @since 2021-06-30
*/
public class UpdateModel { public class UpdateModel {
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
private static volatile UpdateModel updateModel;
static { static {
System.loadLibrary("mindspore-lite-jni"); System.loadLibrary("mindspore-lite-jni");
} }
class RequestUpdateModelBuilder {
private RequestUpdateModel requestUM;
private FlatBufferBuilder builder;
private int fmOffset = 0;
private int nameOffset = 0;
private int idOffset = 0;
private int timestampOffset = 0;
private int iteration = 0;
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
public RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
builder = new FlatBufferBuilder();
this.encryptLevel = encryptLevel;
}
public RequestUpdateModelBuilder flName(String name) {
this.nameOffset = this.builder.createString(name);
return this;
}
public RequestUpdateModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timestampOffset = builder.createString(String.valueOf(time));
return this;
}
public RequestUpdateModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
public RequestUpdateModelBuilder id(String id) {
this.idOffset = this.builder.createString(id);
return this;
}
public RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
switch (encryptLevel) {
case PW_ENCRYPT:
try {
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is null, please check");
throw new RuntimeException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
return this;
} catch (Exception e) {
LOGGER.severe("[Encrypt] catch error in maskModel: " + e.getMessage());
throw new RuntimeException();
}
case DP_ENCRYPT:
try {
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is null, please check");
throw new RuntimeException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
return this;
} catch (Exception e) {
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
throw new RuntimeException();
}
case NOT_ENCRYPT:
default:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
}
int featureSize = encryptFeatureName.size();
int[] fmOffsets = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature size: " + data.length));
for (int j = 0; j < data.length; j++) {
float rawData = data[j];
data[j] = data[j] * trainDataSize;
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
fmOffsets[i] = featureMap;
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
return this;
}
}
public byte[] build() {
RequestUpdateModel.startRequestUpdateModel(this.builder);
RequestUpdateModel.addFlName(builder, nameOffset);
RequestUpdateModel.addFlId(this.builder, idOffset);
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
RequestUpdateModel.addIteration(builder, this.iteration);
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
int root = RequestUpdateModel.endRequestUpdateModel(builder);
builder.finish(root);
return builder.sizedByteArray();
}
}
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
private FLParameter flParameter = FLParameter.getInstance(); private FLParameter flParameter = FLParameter.getInstance();
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private String nextRequestTime;
private FLClientStatus status; private FLClientStatus status;
private static volatile UpdateModel updateModel;
private UpdateModel() { private UpdateModel() {
} }
/**
* Get the singleton object of the class UpdateModel.
*
* @return the singleton object of the class UpdateModel.
*/
public static UpdateModel getInstance() { public static UpdateModel getInstance() {
UpdateModel localRef = updateModel; UpdateModel localRef = updateModel;
if (localRef == null) { if (localRef == null) {
@ -183,25 +74,35 @@ public class UpdateModel {
return localRef; return localRef;
} }
public String getNextRequestTime() {
return nextRequestTime;
}
public FLClientStatus getStatus() { public FLClientStatus getStatus() {
return status; return status;
} }
/**
* Get a flatBuffer builder of RequestUpdateModel.
*
* @param iteration current iteration of federated learning task.
* @param secureProtocol the object that defines encryption and decryption methods.
* @param trainDataSize the size of train date set.
* @return the flatBuffer builder of RequestUpdateModel in byte[] format.
*/
public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) { public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) {
RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel()); RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel());
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID()).featuresMap(secureProtocol, trainDataSize).iteration(iteration).build(); return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID())
.featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
} }
/**
* Handle the response message returned from server.
*
* @param response the response message returned from server.
* @return the status code corresponding to the response message.
*/
public FLClientStatus doResponse(ResponseUpdateModel response) { public FLClientStatus doResponse(ResponseUpdateModel response) {
LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================")); LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================"));
LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode())); LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode()));
LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason())); LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason()));
LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime())); LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime()));
nextRequestTime = response.nextReqTime();
switch (response.retcode()) { switch (response.retcode()) {
case (ResponseCode.SUCCEED): case (ResponseCode.SUCCEED):
LOGGER.info(Common.addTag("[updateModel] updateModel success")); LOGGER.info(Common.addTag("[updateModel] updateModel success"));
@ -213,8 +114,165 @@ public class UpdateModel {
LOGGER.warning(Common.addTag("[updateModel] catch RequestError or SystemError")); LOGGER.warning(Common.addTag("[updateModel] catch RequestError or SystemError"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
default: default:
LOGGER.severe(Common.addTag("[updateModel]the return <retcode> from server is invalid: " + response.retcode())); LOGGER.severe(Common.addTag("[updateModel]the return <retCode> from server is invalid: " +
response.retcode()));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
class RequestUpdateModelBuilder {
private RequestUpdateModel requestUM;
private FlatBufferBuilder builder;
private int fmOffset = 0;
private int nameOffset = 0;
private int idOffset = 0;
private int timestampOffset = 0;
private int iteration = 0;
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
private RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
builder = new FlatBufferBuilder();
this.encryptLevel = encryptLevel;
}
/**
* Serialize the element flName in RequestUpdateModel.
*
* @param name the model name.
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder flName(String name) {
if (name == null || name.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the parameter of <name> is null or empty, please check!"));
throw new IllegalArgumentException();
}
this.nameOffset = this.builder.createString(name);
return this;
}
/**
* Serialize the element timestamp in RequestUpdateModel.
*
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder time() {
Date date = new Date();
long time = date.getTime();
this.timestampOffset = builder.createString(String.valueOf(time));
return this;
}
/**
* Serialize the element iteration in RequestUpdateModel.
*
* @param iteration current iteration of federated learning task.
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder iteration(int iteration) {
this.iteration = iteration;
return this;
}
/**
* Serialize the element fl_id in RequestUpdateModel.
*
* @param id a number that uniquely identifies a client.
* @return the RequestUpdateModelBuilder object.
*/
private RequestUpdateModelBuilder id(String id) {
if (id == null || id.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the parameter of <id> is null or empty, please check!"));
throw new IllegalArgumentException();
}
this.idOffset = this.builder.createString(id);
return this;
}
private RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
switch (encryptLevel) {
case PW_ENCRYPT:
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " +
"null, please check");
throw new IllegalArgumentException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
return this;
case DP_ENCRYPT:
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " +
"null, please check");
throw new IllegalArgumentException();
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
return this;
case NOT_ENCRYPT:
default:
Map<String, float[]> map = new HashMap<String, float[]>();
if (flParameter.getFlName().equals(ALBERT)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
flParameter.getFlName()));
AlTrainBert alTrainBert = AlTrainBert.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
".convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
} else if (flParameter.getFlName().equals(LENET)) {
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
flParameter.getFlName()));
TrainLenet trainLenet = TrainLenet.getInstance();
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
if (map.isEmpty()) {
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
".convertTensorToFeatures>"));
status = FLClientStatus.FAILED;
}
} else {
LOGGER.severe(Common.addTag("[updateModel] the flName is not valid"));
throw new IllegalArgumentException();
}
int featureSize = encryptFeatureName.size();
int[] fmOffsets = new int[featureSize];
for (int i = 0; i < featureSize; i++) {
String key = encryptFeatureName.get(i);
float[] data = map.get(key);
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
"size: " + data.length));
for (int j = 0; j < data.length; j++) {
data[j] = data[j] * trainDataSize;
}
int featureName = builder.createString(key);
int weight = FeatureMap.createDataVector(builder, data);
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
fmOffsets[i] = featureMap;
}
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
return this;
}
}
/**
* Create a flatBuffer builder of RequestUpdateModel.
*
* @return the flatBuffer builder of RequestUpdateModel in byte[] format.
*/
private byte[] build() {
RequestUpdateModel.startRequestUpdateModel(this.builder);
RequestUpdateModel.addFlName(builder, nameOffset);
RequestUpdateModel.addFlId(this.builder, idOffset);
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
RequestUpdateModel.addIteration(builder, this.iteration);
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
int root = RequestUpdateModel.endRequestUpdateModel(builder);
builder.finish(root);
return builder.sizedByteArray();
}
}
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,21 +16,33 @@
package com.mindspore.flclient.cipher; package com.mindspore.flclient.cipher;
import static com.mindspore.flclient.LocalFLParameter.I_VEC_LEN;
import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
import com.mindspore.flclient.Common; import com.mindspore.flclient.Common;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.logging.Logger; import java.util.logging.Logger;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
/**
* Define encryption and decryption methods.
*
* @since 2021-06-30
*/
public class AESEncrypt { public class AESEncrypt {
private static final Logger LOGGER = Logger.getLogger(AESEncrypt.class.toString()); private static final Logger LOGGER = Logger.getLogger(AESEncrypt.class.toString());
/**
* 128, 192 or 256
*/
private static final int KEY_SIZE = 256;
private static final int I_VEC_LEN = 16;
/** /**
* encrypt/decrypt algorithm name * encrypt/decrypt algorithm name
@ -43,62 +55,145 @@ public class AESEncrypt {
private static final String CIPHER_MODE_CTR = "AES/CTR/NoPadding"; private static final String CIPHER_MODE_CTR = "AES/CTR/NoPadding";
private static final String CIPHER_MODE_CBC = "AES/CBC/PKCS5PADDING"; private static final String CIPHER_MODE_CBC = "AES/CBC/PKCS5PADDING";
private String CIPHER_MODE; private String cipherMod;
private static final int RANDOM_LEN = KEY_SIZE / 8; /**
* Defining a Constructor of the class AESEncrypt.
private String iVecS = "1111111111111111"; *
private byte[] iVec = iVecS.getBytes("utf-8"); * @param key the Key.
* @param mode the encryption Mode.
public AESEncrypt(byte[] key, byte[] iVecIn, String mode) throws UnsupportedEncodingException { */
public AESEncrypt(byte[] key, String mode) {
if (key == null) { if (key == null) {
LOGGER.severe(Common.addTag("Key is null")); LOGGER.severe(Common.addTag("Key is null"));
return; return;
} }
if (key.length != KEY_SIZE / 8) { if (key.length != KEY_LEN) {
LOGGER.severe(Common.addTag("the length of key is not correct")); LOGGER.severe(Common.addTag("the length of key is not correct"));
return; return;
} }
if (mode.contains("CBC")) { if (mode.contains("CBC")) {
CIPHER_MODE = CIPHER_MODE_CBC; cipherMod = CIPHER_MODE_CBC;
} else if (mode.contains("CTR")) { } else if (mode.contains("CTR")) {
CIPHER_MODE = CIPHER_MODE_CTR; cipherMod = CIPHER_MODE_CTR;
} else { } else {
return; return;
} }
if (iVecIn == null || iVecIn.length != I_VEC_LEN) { }
return;
/**
* Defining the CBC encryption Mode.
*
* @param key the Key.
* @param data the data to be encrypted.
* @return the data to be encrypted.
*/
public byte[] encrypt(byte[] key, byte[] data) {
if (key == null) {
LOGGER.severe(Common.addTag("Key is null"));
return new byte[0];
}
if (data == null) {
LOGGER.severe(Common.addTag("data is null"));
return new byte[0];
}
try {
byte[] iVec = new byte[I_VEC_LEN];
SecureRandom secureRandom = Common.getSecureRandom();
secureRandom.nextBytes(iVec);
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(cipherMod);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
byte[] encrypted = cipher.doFinal(data);
byte[] encryptedAddIv = new byte[encrypted.length + iVec.length];
System.arraycopy(iVec, 0, encryptedAddIv, 0, iVec.length);
System.arraycopy(encrypted, 0, encryptedAddIv, iVec.length, encrypted.length);
return encryptedAddIv;
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException or " +
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
return new byte[0];
} }
iVec = iVecIn;
} }
public byte[] encrypt(byte[] key, byte[] data) throws Exception { /**
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM); * Defining the CTR encryption Mode.
Cipher cipher = Cipher.getInstance(CIPHER_MODE); *
IvParameterSpec iv = new IvParameterSpec(iVec); * @param key the Key.
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv); * @param data the data to be encrypted.
byte[] encrypted = cipher.doFinal(data); * @param iVec the IV value.
String encryptResultStr = BaseUtil.byte2HexString(encrypted); * @return the data to be encrypted.
return encrypted; */
public byte[] encryptCTR(byte[] key, byte[] data, byte[] iVec) {
if (key == null) {
LOGGER.severe(Common.addTag("Key is null"));
return new byte[0];
}
if (data == null) {
LOGGER.severe(Common.addTag("data is null"));
return new byte[0];
}
if (iVec == null || iVec.length != I_VEC_LEN) {
LOGGER.severe(Common.addTag("iVec is null or the length of iVec is not valid, it should be " + "I_VEC_LEN"
));
return new byte[0];
}
try {
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(cipherMod);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
return cipher.doFinal(data);
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
LOGGER.severe(Common.addTag("[encryptCTR] catch NoSuchAlgorithmException or " +
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
return new byte[0];
}
} }
public byte[] encryptCTR(byte[] key, byte[] data) throws Exception { /**
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM); * Defining the decrypt method.
Cipher cipher = Cipher.getInstance(CIPHER_MODE); *
IvParameterSpec iv = new IvParameterSpec(iVec); * @param key the Key.
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv); * @param encryptDataAddIv the data to be decrypted.
byte[] encrypted = cipher.doFinal(data); * @return the data to be decrypted.
return encrypted; */
public byte[] decrypt(byte[] key, byte[] encryptDataAddIv) {
if (key == null) {
LOGGER.severe(Common.addTag("Key is null"));
return new byte[0];
}
if (encryptDataAddIv == null) {
LOGGER.severe(Common.addTag("encryptDataAddIv is null"));
return new byte[0];
}
if (encryptDataAddIv.length <= I_VEC_LEN) {
LOGGER.severe(Common.addTag("the length of encryptDataAddIv is not valid: " + encryptDataAddIv.length +
", it should be > " + I_VEC_LEN));
return new byte[0];
}
try {
byte[] iVec = Arrays.copyOfRange(encryptDataAddIv, 0, I_VEC_LEN);
byte[] encryptData = Arrays.copyOfRange(encryptDataAddIv, I_VEC_LEN, encryptDataAddIv.length);
SecretKeySpec sKeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(cipherMod);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.DECRYPT_MODE, sKeySpec, iv);
return cipher.doFinal(encryptData);
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException or " +
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
return new byte[0];
}
} }
public byte[] decrypt(byte[] key, byte[] encryptData) throws Exception {
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
IvParameterSpec iv = new IvParameterSpec(iVec);
cipher.init(Cipher.DECRYPT_MODE, skeySpec, iv);
byte[] origin = cipher.doFinal(encryptData);
return origin;
}
} }

View File

@ -1,18 +1,19 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* <p> *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* <p> *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* <p> *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient.cipher; package com.mindspore.flclient.cipher;
import java.io.UnsupportedEncodingException; import java.io.UnsupportedEncodingException;
@ -21,14 +22,23 @@ import java.nio.charset.Charset;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
/**
* Define conversion methods between basic data types.
*
* @since 2021-06-30
*/
public class BaseUtil { public class BaseUtil {
private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'}; private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B',
'C', 'D', 'E', 'F'};
public BaseUtil() {
}
/**
* Convert byte[] to String in hexadecimal format.
*
* @param bytes the byte[] object.
* @return the String object converted from byte[].
*/
public static String byte2HexString(byte[] bytes) { public static String byte2HexString(byte[] bytes) {
if (null == bytes) { if (bytes == null) {
return null; return null;
} else if (bytes.length == 0) { } else if (bytes.length == 0) {
return ""; return "";
@ -36,14 +46,20 @@ public class BaseUtil {
char[] chars = new char[bytes.length * 2]; char[] chars = new char[bytes.length * 2];
for (int i = 0; i < bytes.length; ++i) { for (int i = 0; i < bytes.length; ++i) {
int b = bytes[i]; int byteNum = bytes[i];
chars[i * 2] = HEX_DIGITS[(b & 240) >> 4]; chars[i * 2] = HEX_DIGITS[(byteNum & 240) >> 4];
chars[i * 2 + 1] = HEX_DIGITS[b & 15]; chars[i * 2 + 1] = HEX_DIGITS[byteNum & 15];
} }
return new String(chars); return new String(chars);
} }
} }
/**
* Convert String in hexadecimal format to byte[].
*
* @param str the String object.
* @return the byte[] converted from String object.
*/
public static byte[] hexString2ByteArray(String str) { public static byte[] hexString2ByteArray(String str) {
int length = str.length() / 2; int length = str.length() / 2;
byte[] bytes = new byte[length]; byte[] bytes = new byte[length];
@ -58,8 +74,13 @@ public class BaseUtil {
return bytes; return bytes;
} }
/**
* Convert byte[] to BigInteger.
*
* @param bytes the byte[] object.
* @return the BigInteger object converted from byte[].
*/
public static BigInteger byteArray2BigInteger(byte[] bytes) { public static BigInteger byteArray2BigInteger(byte[] bytes) {
BigInteger bigInteger = BigInteger.ZERO; BigInteger bigInteger = BigInteger.ZERO;
for (int i = 0; i < bytes.length; ++i) { for (int i = 0; i < bytes.length; ++i) {
int intI = bytes[i]; int intI = bytes[i];
@ -72,6 +93,13 @@ public class BaseUtil {
return bigInteger; return bigInteger;
} }
/**
* Convert String to BigInteger.
*
* @param str the String object.
* @return the BigInteger object converted from String object.
* @throws UnsupportedEncodingException if the encoding is not supported.
*/
public static BigInteger string2BigInteger(String str) throws UnsupportedEncodingException { public static BigInteger string2BigInteger(String str) throws UnsupportedEncodingException {
StringBuilder res = new StringBuilder(); StringBuilder res = new StringBuilder();
byte[] bytes = String.valueOf(str).getBytes("UTF-8"); byte[] bytes = String.valueOf(str).getBytes("UTF-8");
@ -83,14 +111,20 @@ public class BaseUtil {
return bigInteger; return bigInteger;
} }
public static String bigInteger2String(BigInteger bigInteger) throws UnsupportedEncodingException { /**
* Convert BigInteger to String.
*
* @param bigInteger the BigInteger object.
* @return the String object converted from BigInteger.
*/
public static String bigInteger2String(BigInteger bigInteger) {
StringBuilder res = new StringBuilder(); StringBuilder res = new StringBuilder();
List<Integer> lists = new ArrayList<>(); List<Integer> lists = new ArrayList<>();
BigInteger bi = bigInteger; BigInteger bi = bigInteger;
BigInteger DIV = BigInteger.valueOf(256); BigInteger div = BigInteger.valueOf(256);
while (bi.compareTo(BigInteger.ZERO) > 0) { while (bi.compareTo(BigInteger.ZERO) > 0) {
lists.add(bi.mod(DIV).intValue()); lists.add(bi.mod(div).intValue());
bi = bi.divide(DIV); bi = bi.divide(div);
} }
for (int i = lists.size() - 1; i >= 0; --i) { for (int i = lists.size() - 1; i >= 0; --i) {
res.append((char) (int) (lists.get(i))); res.append((char) (int) (lists.get(i)));
@ -98,13 +132,19 @@ public class BaseUtil {
return res.toString(); return res.toString();
} }
public static byte[] bigInteger2byteArray(BigInteger bigInteger) throws UnsupportedEncodingException { /**
* Convert BigInteger to byte[].
*
* @param bigInteger the BigInteger object.
* @return the byte[] object converted from BigInteger.
*/
public static byte[] bigInteger2byteArray(BigInteger bigInteger) {
List<Integer> lists = new ArrayList<>(); List<Integer> lists = new ArrayList<>();
BigInteger bi = bigInteger; BigInteger bi = bigInteger;
BigInteger DIV = BigInteger.valueOf(256); BigInteger div = BigInteger.valueOf(256);
while (bi.compareTo(BigInteger.ZERO) > 0) { while (bi.compareTo(BigInteger.ZERO) > 0) {
lists.add(bi.mod(DIV).intValue()); lists.add(bi.mod(div).intValue());
bi = bi.divide(DIV); bi = bi.divide(div);
} }
byte[] res = new byte[lists.size()]; byte[] res = new byte[lists.size()];
for (int i = lists.size() - 1; i >= 0; --i) { for (int i = lists.size() - 1; i >= 0; --i) {
@ -113,13 +153,19 @@ public class BaseUtil {
return res; return res;
} }
/**
* Convert Integer to byte[].
*
* @param num the Integer object.
* @return the byte[] object converted from Integer.
*/
public static byte[] integer2byteArray(Integer num) { public static byte[] integer2byteArray(Integer num) {
List<Integer> lists = new ArrayList<>(); List<Integer> lists = new ArrayList<>();
Integer bi = num; Integer bi = num;
Integer DIV = 256; Integer div = 256;
while (bi > 0) { while (bi > 0) {
lists.add(bi % DIV); lists.add(bi % div);
bi = bi / DIV; bi = bi / div;
} }
byte[] res = new byte[lists.size()]; byte[] res = new byte[lists.size()];
for (int i = lists.size() - 1; i >= 0; --i) { for (int i = lists.size() - 1; i >= 0; --i) {
@ -128,8 +174,13 @@ public class BaseUtil {
return res; return res;
} }
/**
* Convert byte[] to Integer.
*
* @param bytes the byte[] object.
* @return the Integer object converted from byte[].
*/
public static Integer byteArray2Integer(byte[] bytes) { public static Integer byteArray2Integer(byte[] bytes) {
Integer num = 0; Integer num = 0;
for (int i = 0; i < bytes.length; ++i) { for (int i = 0; i < bytes.length; ++i) {
int intI = bytes[i]; int intI = bytes[i];

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,9 +14,13 @@
* limitations under the License. * limitations under the License.
*/ */
package com.mindspore.flclient.cipher; package com.mindspore.flclient.cipher;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.Common; import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus; import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLCommunication; import com.mindspore.flclient.FLCommunication;
@ -25,23 +29,27 @@ import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets; import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import com.mindspore.flclient.cipher.struct.EncryptShare; import com.mindspore.flclient.cipher.struct.EncryptShare;
import com.mindspore.flclient.cipher.struct.NewArray; import com.mindspore.flclient.cipher.struct.NewArray;
import mindspore.schema.GetClientList; import mindspore.schema.GetClientList;
import mindspore.schema.ResponseCode; import mindspore.schema.ResponseCode;
import mindspore.schema.ReturnClientList; import mindspore.schema.ReturnClientList;
import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.time.LocalDateTime;
import java.util.Arrays; import java.util.Arrays;
import java.util.Date;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME; /**
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN; * Define the serialization method, handle the response message returned from server for GetClientList request.
*
* @since 2021-06-30
*/
public class ClientListReq { public class ClientListReq {
private static final Logger LOGGER = Logger.getLogger(ClientListReq.class.toString()); private static final Logger LOGGER = Logger.getLogger(ClientListReq.class.toString());
private FLCommunication flCommunication; private FLCommunication flCommunication;
private String nextRequestTime; private String nextRequestTime;
private FLParameter flParameter = FLParameter.getInstance(); private FLParameter flParameter = FLParameter.getInstance();
@ -64,34 +72,63 @@ public class ClientListReq {
return retCode; return retCode;
} }
public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) { /**
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); * Send serialized request message of GetClientList to server.
*
* @param iteration current iteration of federated learning task.
* @param u3ClientList list of clients successfully requested in UpdateModel round.
* @param decryptSecretsList list to store to decrypted secret fragments.
* @param returnShareList List of returned secret fragments from server.
* @param cuvKeys Keys used to decrypt secret fragments.
* @return the status code corresponding to the response message.
*/
public FLClientStatus getClientList(int iteration, List<String> u3ClientList,
List<DecryptShareSecrets> decryptSecretsList,
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
FlatBufferBuilder builder = new FlatBufferBuilder(); FlatBufferBuilder builder = new FlatBufferBuilder();
int id = builder.createString(localFLParameter.getFlID()); int id = builder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString(); Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = builder.createString(dateTime); int time = builder.createString(dateTime);
int clientListRoot = GetClientList.createGetClientList(builder, id, iteration, time); int clientListRoot = GetClientList.createGetClientList(builder, id, iteration, time);
builder.finish(clientListRoot); builder.finish(clientListRoot);
byte[] msg = builder.sizedByteArray(); byte[] msg = builder.sizedByteArray();
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
try { try {
byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg); byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[getClientList] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[getClientList] the server is not ready now, need wait some time and " +
"request again"));
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
nextRequestTime = ""; nextRequestTime = "";
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
} }
ByteBuffer buffer = ByteBuffer.wrap(responseData); ByteBuffer buffer = ByteBuffer.wrap(responseData);
LOGGER.info(Common.addTag("getClientList responseData size: " + responseData.length));
ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer); ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer);
FLClientStatus status = judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys); return judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);
return status; } catch (IOException ex) {
} catch (Exception e) { LOGGER.severe(Common.addTag("[getClientList] unsolved error code in getClientList: catch IOException: " +
e.printStackTrace(); ex.getMessage()));
retCode = ResponseCode.RequestError;
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
public FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) { /**
* Analyze the serialization message returned from server and perform corresponding processing.
*
* @param bufData Serialized message returned from server.
* @param u3ClientList list of clients successfully requested in UpdateModel round.
* @param decryptSecretsList list to store decrypted secret fragments.
* @param returnShareList List of returned secret fragments from server.
* @param cuvKeys Keys used to decrypt secret fragments.
* @return the status code corresponding to the response message.
*/
private FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList,
List<DecryptShareSecrets> decryptSecretsList,
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
retCode = bufData.retcode(); retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************")); LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -109,18 +146,15 @@ public class ClientListReq {
String curFlId = bufData.clients(i); String curFlId = bufData.clients(i);
u3ClientList.add(curFlId); u3ClientList.add(curFlId);
} }
try { status = decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys);
decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys); return status;
} catch (Exception e) {
e.printStackTrace();
return FLClientStatus.FAILED;
}
return FLClientStatus.SUCCESS;
case (ResponseCode.SucNotReady): case (ResponseCode.SucNotReady):
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetClientList again!")); LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
"GetClientList again!"));
return FLClientStatus.WAIT; return FLClientStatus.WAIT;
case (ResponseCode.OutOfTime): case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob again")); LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob" +
" again"));
setNextRequestTime(bufData.nextReqTime()); setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
case (ResponseCode.RequestError): case (ResponseCode.RequestError):
@ -128,36 +162,66 @@ public class ClientListReq {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList")); LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
default: default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is invalid: " + retCode)); LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is " +
"invalid: " + retCode));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
public void decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) throws Exception { private FLClientStatus decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList,
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
decryptSecretsList.clear(); decryptSecretsList.clear();
int size = returnShareList.size(); int size = returnShareList.size();
if (size <= 0) {
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <returnShareList> is null"));
return FLClientStatus.FAILED;
}
if (cuvKeys.isEmpty()) {
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <cuvKeys> is null"));
return FLClientStatus.FAILED;
}
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
EncryptShare encryptShare = returnShareList.get(i); EncryptShare encryptShare = returnShareList.get(i);
String vFlID = encryptShare.getFlID(); String vFlID = encryptShare.getFlID();
byte[] share = encryptShare.getShare().getArray(); byte[] share = encryptShare.getShare().getArray();
byte[] iVecIn = new byte[IVEC_LEN]; if (!cuvKeys.containsKey(vFlID)) {
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), iVecIn, "CBC"); LOGGER.severe(Common.addTag("[PairWiseMask] the key <vFlID> is not in map <cuvKeys> "));
return FLClientStatus.FAILED;
}
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), "CBC");
byte[] decryptShare = aesEncrypt.decrypt(cuvKeys.get(vFlID), share); byte[] decryptShare = aesEncrypt.decrypt(cuvKeys.get(vFlID), share);
if (decryptShare == null || decryptShare.length == 0) {
LOGGER.severe(Common.addTag("[decryptSecretShares] the return byte[] is null, please check!"));
return FLClientStatus.FAILED;
}
if (decryptShare.length < 4) {
LOGGER.severe(Common.addTag("[decryptSecretShares] the returned decryptShare is not valid: length is " +
"not right, please check!"));
return FLClientStatus.FAILED;
}
int sSize = (int) decryptShare[0]; int sSize = (int) decryptShare[0];
int bSize = (int) decryptShare[1]; int bSize = (int) decryptShare[1];
int sIndexLen = (int) decryptShare[2]; int sIndexLen = (int) decryptShare[2];
int bIndexLen = (int) decryptShare[3]; int bIndexLen = (int) decryptShare[3];
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen)); if (decryptShare.length < (4 + sIndexLen + bIndexLen + sSize + bSize)) {
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen, 4 + sIndexLen + bIndexLen)); LOGGER.severe(Common.addTag("[decryptSecretShares] the returned decryptShare is not valid: length is " +
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen, 4 + sIndexLen + bIndexLen + sSize); "not right, please check!"));
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize, 4 + sIndexLen + bIndexLen + sSize + bSize); return FLClientStatus.FAILED;
}
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen,
4 + sIndexLen + bIndexLen + sSize);
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize,
4 + sIndexLen + bIndexLen + sSize + bSize);
NewArray<byte[]> sSkVu = new NewArray<>(); NewArray<byte[]> sSkVu = new NewArray<>();
sSkVu.setSize(sSize); sSkVu.setSize(sSize);
sSkVu.setArray(sSkUv); sSkVu.setArray(sSkUv);
NewArray bVu = new NewArray(); NewArray bVu = new NewArray();
bVu.setSize(bSize); bVu.setSize(bSize);
bVu.setArray(bUv); bVu.setArray(bUv);
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen));
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen,
4 + sIndexLen + bIndexLen));
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
decryptShareSecrets.setFlID(vFlID); decryptShareSecrets.setFlID(vFlID);
decryptShareSecrets.setSSkVu(sSkVu); decryptShareSecrets.setSSkVu(sSkVu);
decryptShareSecrets.setBVu(bVu); decryptShareSecrets.setBVu(bVu);
@ -165,5 +229,6 @@ public class ClientListReq {
decryptShareSecrets.setIndexB(bIndex); decryptShareSecrets.setIndexB(bIndex);
decryptSecretsList.add(decryptShareSecrets); decryptSecretsList.add(decryptShareSecrets);
} }
return FLClientStatus.SUCCESS;
} }
} }

View File

@ -1,6 +1,5 @@
/*
/** * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,6 +16,10 @@
package com.mindspore.flclient.cipher; package com.mindspore.flclient.cipher;
import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
import com.mindspore.flclient.Common;
import org.bouncycastle.crypto.digests.SHA256Digest; import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator; import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator;
import org.bouncycastle.crypto.params.KeyParameter; import org.bouncycastle.crypto.params.KeyParameter;
@ -25,39 +28,80 @@ import org.bouncycastle.math.ec.rfc7748.X25519;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.logging.Logger; import java.util.logging.Logger;
/**
* Generate public-private key pairs and DH Keys.
*
* @since 2021-06-30
*/
public class KEYAgreement { public class KEYAgreement {
private static final Logger LOGGER = Logger.getLogger(KEYAgreement.class.toString()); private static final Logger LOGGER = Logger.getLogger(KEYAgreement.class.toString());
private static final int PBKDF2_ITERATIONS = 10000; private static final int PBKDF2_ITERATIONS = 10000;
private static final int SALT_SIZE = 32;
private static final int HASH_BIT_SIZE = 256; private static final int HASH_BIT_SIZE = 256;
private static final int KEY_LEN = X25519.SCALAR_SIZE;
private SecureRandom random = new SecureRandom(); private SecureRandom random = Common.getSecureRandom();
/**
* Generate private Key.
*
* @return the private Key.
*/
public byte[] generatePrivateKey() { public byte[] generatePrivateKey() {
byte[] privateKey = new byte[KEY_LEN]; byte[] privateKey = new byte[KEY_LEN];
X25519.generatePrivateKey(random, privateKey); X25519.generatePrivateKey(random, privateKey);
return privateKey; return privateKey;
} }
public byte[] generatePublicKey(byte[] privatekey) { /**
* Use private Key to generate public Key.
*
* @param privateKey the private Key.
* @return the public Key.
*/
public byte[] generatePublicKey(byte[] privateKey) {
if (privateKey == null || privateKey.length == 0) {
LOGGER.severe(Common.addTag("privateKey is null"));
return new byte[0];
}
byte[] publicKey = new byte[KEY_LEN]; byte[] publicKey = new byte[KEY_LEN];
X25519.generatePublicKey(privatekey, 0, publicKey, 0); X25519.generatePublicKey(privateKey, 0, publicKey, 0);
return publicKey; return publicKey;
} }
public byte[] keyAgreement(byte[] privatekey, byte[] publicKey) { /**
* Use private Key and public Key to generate DH Key.
*
* @param privateKey the private Key.
* @param publicKey the public Key.
* @return the DH Key.
*/
public byte[] keyAgreement(byte[] privateKey, byte[] publicKey) {
if (privateKey == null || privateKey.length == 0) {
LOGGER.severe(Common.addTag("privateKey is null"));
return new byte[0];
}
if (publicKey == null || publicKey.length == 0) {
LOGGER.severe(Common.addTag("publicKey is null"));
return new byte[0];
}
byte[] secret = new byte[KEY_LEN]; byte[] secret = new byte[KEY_LEN];
X25519.calculateAgreement(privatekey, 0, publicKey, 0, secret, 0); X25519.calculateAgreement(privateKey, 0, publicKey, 0, secret, 0);
return secret; return secret;
} }
/**
* Encrypt DH Key.
*
* @param password the DH Key.
* @param salt the salt value.
* @return encrypted DH Key.
*/
public byte[] getEncryptedPassword(byte[] password, byte[] salt) { public byte[] getEncryptedPassword(byte[] password, byte[] salt) {
if (password == null || password.length == 0) {
byte[] saltB = new byte[SALT_SIZE]; LOGGER.severe(Common.addTag("password is null"));
return new byte[0];
}
PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest()); PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest());
gen.init(password, saltB, PBKDF2_ITERATIONS); gen.init(password, salt, PBKDF2_ITERATIONS);
byte[] dk = ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey(); return ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey();
return dk;
} }
} }

View File

@ -0,0 +1,115 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import com.mindspore.flclient.Common;
import java.security.SecureRandom;
import java.util.List;
import java.util.logging.Logger;
/**
* Define the basic method for generating pairwise mask and individual mask.
*
* @since 2021-06-30
*/
public class Masking {
private static final Logger LOGGER = Logger.getLogger(Masking.class.toString());
/**
* Random generate RNG algorithm name.
*/
private static final String RNG_ALGORITHM = "SHA1PRNG";
/**
* Generate individual mask.
*
* @param secret used to store individual mask.
* @return the int value, 0 indicates success, -1 indicates failed .
*/
public int getRandomBytes(byte[] secret) {
if (secret == null || secret.length == 0) {
LOGGER.severe(Common.addTag("[Masking] the input argument <secret> is null, please check!"));
return -1;
}
SecureRandom secureRandom = Common.getSecureRandom();
secureRandom.nextBytes(secret);
return 0;
}
/**
* Generate pairwise mask.
*
* @param noise used to store pairwise mask.
* @param length the length of individual mask.
* @param seed the seed for generate pairwise mask.
* @param iVec the IV value.
* @return the int value, 0 indicates success, -1 indicates failed .
*/
public int getMasking(List<Float> noise, int length, byte[] seed, byte[] iVec) {
if (length <= 0) {
LOGGER.severe(Common.addTag("[Masking] the input argument <length> is not valid: <= 0, please check!"));
return -1;
}
int intV = Integer.SIZE / 8;
int size = length * intV;
byte[] data = new byte[size];
for (int i = 0; i < size; i++) {
data[i] = 0;
}
AESEncrypt aesEncrypt = new AESEncrypt(seed, "CTR");
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data, iVec);
if (encryptCtr == null || encryptCtr.length == 0) {
LOGGER.severe(Common.addTag("[Masking] the return byte[] is null, please check!"));
return -1;
}
for (int i = 0; i < length; i++) {
int[] sub = new int[intV];
for (int j = 0; j < 4; j++) {
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
}
int subI = byte2int(sub, 4);
if (subI == -1) {
LOGGER.severe(Common.addTag("[Masking] the the returned <subI> is not valid: -1, please check!"));
return -1;
}
Float fNoise = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
noise.add(fNoise);
}
return 0;
}
private static int byte2int(int[] data, int number) {
if (data.length < 4) {
LOGGER.severe(Common.addTag("[Masking] the input argument <data> is not valid: length < 4, please check!"));
return -1;
}
switch (number) {
case 1:
return (int) data[0];
case 2:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
case 3:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
case 4:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
| (data[3] << 24 & 0xff000000);
default:
return 0;
}
}
}

View File

@ -1,82 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mindspore.flclient.cipher;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.List;
import java.util.logging.Logger;
public class Random {
/**
* random generate RNG algorithm name
*/
private static final Logger LOGGER = Logger.getLogger(Random.class.toString());
private static final String RNG_ALGORITHM = "SHA1PRNG";
private static final int RANDOM_LEN = 128 / 8;
public void getRandomBytes(byte[] secret) {
try {
SecureRandom secureRandom = SecureRandom.getInstance("SHA1PRNG");
secureRandom.nextBytes(secret);
} catch (NoSuchAlgorithmException e) {
e.printStackTrace();
}
}
public void randomAESCTR(List<Float> noise, int length, byte[] seed) throws Exception {
int intV = Integer.SIZE / 8;
int size = length * intV;
byte[] data = new byte[size];
for (int i = 0; i < size; i++) {
data[i] = 0;
}
byte[] ivec = new byte[RANDOM_LEN];
AESEncrypt aesEncrypt = new AESEncrypt(seed, ivec, "CTR");
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data);
for (int i = 0; i < length; i++) {
int[] sub = new int[intV];
for (int j = 0; j < 4; j++) {
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
}
int subI = byte2int(sub, 4);
Float f = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
noise.add(f);
}
}
public static int byte2int(int[] data, int n) {
switch (n) {
case 1:
return (int) data[0];
case 2:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
case 3:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
case 4:
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
| (data[3] << 24 & 0xff000000);
default:
return 0;
}
}
}

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,23 +16,33 @@
package com.mindspore.flclient.cipher; package com.mindspore.flclient.cipher;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import com.mindspore.flclient.Common; import com.mindspore.flclient.Common;
import com.mindspore.flclient.FLClientStatus; import com.mindspore.flclient.FLClientStatus;
import com.mindspore.flclient.FLCommunication; import com.mindspore.flclient.FLCommunication;
import com.mindspore.flclient.FLParameter; import com.mindspore.flclient.FLParameter;
import com.mindspore.flclient.LocalFLParameter; import com.mindspore.flclient.LocalFLParameter;
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets; import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
import mindspore.schema.ClientShare;
import mindspore.schema.ResponseCode;
import mindspore.schema.ClientShare;
import mindspore.schema.ReconstructSecret;
import mindspore.schema.ResponseCode;
import mindspore.schema.SendReconstructSecret;
import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.time.LocalDateTime; import java.util.Date;
import java.util.List; import java.util.List;
import java.util.logging.Logger; import java.util.logging.Logger;
import static com.mindspore.flclient.FLParameter.SLEEP_TIME; /**
* reconstruct secret request
*
* @since 2021-8-27
*/
public class ReconstructSecretReq { public class ReconstructSecretReq {
private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString()); private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString());
private FLCommunication flCommunication; private FLCommunication flCommunication;
@ -41,36 +51,44 @@ public class ReconstructSecretReq {
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance(); private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
private int retCode; private int retCode;
public String getNextRequestTime() { /**
return nextRequestTime; * reconstruct secret request
} */
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
public int getRetCode() {
return retCode;
}
public ReconstructSecretReq() { public ReconstructSecretReq() {
flCommunication = FLCommunication.getInstance(); flCommunication = FLCommunication.getInstance();
} }
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList, List<String> u3ClientList, int iteration) { /**
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum()); * send secret shards to server
*
* @param decryptShareSecretsList secret shards list
* @param u3ClientList u3 client list
* @param iteration iter number
* @return request result
*/
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList,
List<String> u3ClientList, int iteration) {
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
FlatBufferBuilder builder = new FlatBufferBuilder(); FlatBufferBuilder builder = new FlatBufferBuilder();
int desFlId = builder.createString(localFLParameter.getFlID()); int desFlId = builder.createString(localFLParameter.getFlID());
String dateTime = LocalDateTime.now().toString(); Date date = new Date();
long timestamp = date.getTime();
String dateTime = String.valueOf(timestamp);
int time = builder.createString(dateTime); int time = builder.createString(dateTime);
int shareSecretsSize = decryptShareSecretsList.size(); int shareSecretsSize = decryptShareSecretsList.size();
if (shareSecretsSize <= 0) { if (shareSecretsSize <= 0) {
LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please waite.")); LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please " +
"waite."));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} else { } else {
int[] decryptShareList = new int[shareSecretsSize]; int[] decryptShareList = new int[shareSecretsSize];
for (int i = 0; i < shareSecretsSize; i++) { for (int i = 0; i < shareSecretsSize; i++) {
DecryptShareSecrets decryptShareSecrets = decryptShareSecretsList.get(i); DecryptShareSecrets decryptShareSecrets = decryptShareSecretsList.get(i);
if (decryptShareSecrets.getFlID() == null) {
LOGGER.severe(Common.addTag("[PairWiseMask] get remote flID failed!"));
return FLClientStatus.FAILED;
}
String srcFlId = decryptShareSecrets.getFlID(); String srcFlId = decryptShareSecrets.getFlID();
byte[] share; byte[] share;
int index; int index;
@ -86,31 +104,33 @@ public class ReconstructSecretReq {
int clientShare = ClientShare.createClientShare(builder, fbsSrcFlId, fbsShare, index); int clientShare = ClientShare.createClientShare(builder, fbsSrcFlId, fbsShare, index);
decryptShareList[i] = clientShare; decryptShareList[i] = clientShare;
} }
int reconstructShareSecrets = mindspore.schema.SendReconstructSecret.createReconstructSecretSharesVector(builder, decryptShareList); int reconstructShareSecrets = SendReconstructSecret.createReconstructSecretSharesVector(builder,
int reconstructSecretRoot = mindspore.schema.SendReconstructSecret.createSendReconstructSecret(builder, desFlId, reconstructShareSecrets, iteration, time); decryptShareList);
int reconstructSecretRoot = SendReconstructSecret.createSendReconstructSecret(builder, desFlId,
reconstructShareSecrets, iteration, time);
builder.finish(reconstructSecretRoot); builder.finish(reconstructSecretRoot);
byte[] msg = builder.sizedByteArray(); byte[] msg = builder.sizedByteArray();
try { try {
byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg); byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg);
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) { if (!Common.isSeverReady(responseData)) {
LOGGER.info(Common.addTag("[sendReconstructSecret] The cluster is in safemode, need wait some time and request again")); LOGGER.info(Common.addTag("[sendReconstructSecret] the server is not ready now, need wait some " +
"time and request again"));
Common.sleep(SLEEP_TIME); Common.sleep(SLEEP_TIME);
nextRequestTime = ""; nextRequestTime = "";
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
} }
ByteBuffer buffer = ByteBuffer.wrap(responseData); ByteBuffer buffer = ByteBuffer.wrap(responseData);
mindspore.schema.ReconstructSecret reconstructSecretRsp = mindspore.schema.ReconstructSecret.getRootAsReconstructSecret(buffer); ReconstructSecret reconstructSecretRsp = ReconstructSecret.getRootAsReconstructSecret(buffer);
FLClientStatus status = judgeSendReconstructSecrets(reconstructSecretRsp); return judgeSendReconstructSecrets(reconstructSecretRsp);
return status; } catch (IOException ex) {
} catch (Exception e) {
LOGGER.severe(Common.addTag("[PairWiseMask] un solved error code in reconstruct")); LOGGER.severe(Common.addTag("[PairWiseMask] un solved error code in reconstruct"));
e.printStackTrace(); ex.printStackTrace();
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
} }
public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) { private FLClientStatus judgeSendReconstructSecrets(ReconstructSecret bufData) {
retCode = bufData.retcode(); retCode = bufData.retcode();
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************")); LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************"));
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode)); LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
@ -122,7 +142,8 @@ public class ReconstructSecretReq {
LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success")); LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success"));
return FLClientStatus.SUCCESS; return FLClientStatus.SUCCESS;
case (ResponseCode.OutOfTime): case (ResponseCode.OutOfTime):
LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request startFLJob again")); LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request " +
"startFLJob again"));
setNextRequestTime(bufData.nextReqTime()); setNextRequestTime(bufData.nextReqTime());
return FLClientStatus.RESTART; return FLClientStatus.RESTART;
case (ResponseCode.RequestError): case (ResponseCode.RequestError):
@ -130,8 +151,36 @@ public class ReconstructSecretReq {
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets")); LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets"));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
default: default:
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is invalid: " + retCode)); LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is " +
"invalid: " + retCode));
return FLClientStatus.FAILED; return FLClientStatus.FAILED;
} }
} }
/**
* get next request time
*
* @return next request time
*/
public String getNextRequestTime() {
return nextRequestTime;
}
/**
* set next request time
*
* @param nextRequestTime next request time
*/
public void setNextRequestTime(String nextRequestTime) {
this.nextRequestTime = nextRequestTime;
}
/**
* get retCode
*
* @return retCode
*/
public int getRetCode() {
return retCode;
}
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,115 +22,158 @@ import java.math.BigInteger;
import java.util.Random; import java.util.Random;
import java.util.logging.Logger; import java.util.logging.Logger;
/**
* Define functions that for splitting secret and combining secret shards.
*
* @since 2021-06-30
*/
public class ShareSecrets { public class ShareSecrets {
private static final Logger LOGGER = Logger.getLogger(ShareSecrets.class.toString()); private static final Logger LOGGER = Logger.getLogger(ShareSecrets.class.toString());
public final class SecretShare { private BigInteger prime;
public SecretShare(final int num, final BigInteger share) { private final int minNum;
this.num = num; private final int totalNum;
this.share = share; private final Random random;
}
public int getNum() { /**
return num; * Defines the constructor of the class ShareSecrets.
*
* @param minNum minimum number of fragments required to reconstruct a secret.
* @param totalNum total clients number.
*/
public ShareSecrets(final int minNum, final int totalNum) {
if (minNum <= 0) {
LOGGER.severe(Common.addTag("the argument <k> is not valid: <= 0, it should be > 0"));
throw new IllegalArgumentException();
} }
if (totalNum <= 0) {
public BigInteger getShare() { LOGGER.severe(Common.addTag("the argument <n> is not valid: <= 0, it should be > 0"));
return share; throw new IllegalArgumentException();
} }
if (minNum > totalNum) {
@Override LOGGER.severe(Common.addTag("the argument <k, n> is not valid: k > n, it should k <= n"));
public String toString() { throw new IllegalArgumentException();
return "SecretShare [num=" + num + ", share=" + share + "]";
} }
this.minNum = minNum;
private final int num; this.totalNum = totalNum;
private final BigInteger share; random = Common.getSecureRandom();
} }
public ShareSecrets(final int k, final int n) { /**
this.k = k; * Splits a secret into a specified number of secret fragments.
this.n = n; *
* @param bytes the secret need to be split.
random = new Random(); * @param primeByte teh big prime number used to combine secret fragments.
} * @return the secret fragments.
*/
public SecretShare[] split(final byte[] bytes, byte[] primeByte) { public SecretShares[] split(final byte[] bytes, byte[] primeByte) {
if (bytes == null || bytes.length == 0) {
LOGGER.severe(Common.addTag("the input argument <bytes> is null"));
return new SecretShares[0];
}
if (primeByte == null || primeByte.length == 0) {
LOGGER.severe(Common.addTag("the input argument <primeByte> is null"));
return new SecretShares[0];
}
BigInteger secret = BaseUtil.byteArray2BigInteger(bytes); BigInteger secret = BaseUtil.byteArray2BigInteger(bytes);
final int modLength = secret.bitLength() + 1; final int modLength = secret.bitLength() + 1;
prime = BaseUtil.byteArray2BigInteger(primeByte); prime = BaseUtil.byteArray2BigInteger(primeByte);
final BigInteger[] coeff = new BigInteger[k - 1]; final BigInteger[] coefficient = new BigInteger[minNum - 1];
LOGGER.info(Common.addTag("Prime Number: " + prime)); for (int i = 0; i < minNum - 1; i++) {
coefficient[i] = randomZp(prime);
for (int i = 0; i < k - 1; i++) {
coeff[i] = randomZp(prime);
LOGGER.info(Common.addTag("a" + (i + 1) + ": " + coeff[i]));
} }
final SecretShare[] shares = new SecretShare[n]; final SecretShares[] shares = new SecretShares[totalNum];
for (int i = 1; i <= n; i++) { for (int i = 1; i <= totalNum; i++) {
BigInteger accum = secret; BigInteger accumulate = secret;
for (int j = 1; j < k; j++) { for (int j = 1; j < minNum; j++) {
final BigInteger t1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime); final BigInteger b1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime);
final BigInteger t2 = coeff[j - 1].multiply(t1).mod(prime); final BigInteger b2 = coefficient[j - 1].multiply(b1).mod(prime);
accum = accum.add(t2).mod(prime); accumulate = accumulate.add(b2).mod(prime);
} }
shares[i - 1] = new SecretShare(i, accum); shares[i - 1] = new SecretShares(i, accumulate);
LOGGER.info(Common.addTag("Share " + shares[i - 1]));
} }
return shares; return shares;
} }
public BigInteger getPrime() { /**
return prime; * Combine secret fragments.
} *
* @param shares the secret fragments.
public BigInteger combine(final SecretShare[] shares, final byte[] primeByte) { * @param primeByte teh big prime number used to combine secret fragments.
* @return the secrets combined by secret fragments.
*/
public BigInteger combine(final SecretShares[] shares, final byte[] primeByte) {
if (shares == null || shares.length == 0) {
LOGGER.severe(Common.addTag("the input argument <shares> is null"));
return BigInteger.ZERO;
}
if (primeByte == null || primeByte.length == 0) {
LOGGER.severe(Common.addTag("the input argument <primeByte> is null"));
return BigInteger.ZERO;
}
BigInteger primeNum = BaseUtil.byteArray2BigInteger(primeByte); BigInteger primeNum = BaseUtil.byteArray2BigInteger(primeByte);
BigInteger accum = BigInteger.ZERO; BigInteger accumulate = BigInteger.ZERO;
for (int j = 0; j < k; j++) { for (int j = 0; j < minNum; j++) {
BigInteger num = BigInteger.ONE; BigInteger num = BigInteger.ONE;
BigInteger den = BigInteger.ONE; BigInteger den = BigInteger.ONE;
BigInteger tmp; BigInteger tmp;
for (int m = 0; m < k; m++) { for (int m = 0; m < minNum; m++) {
if (j != m) { if (j != m) {
num = num.multiply(BigInteger.valueOf(shares[m].getNum())).mod(primeNum); num = num.multiply(BigInteger.valueOf(shares[m].getNumber())).mod(primeNum);
tmp = BigInteger.valueOf(shares[j].getNum()).multiply(BigInteger.valueOf(-1)); tmp = BigInteger.valueOf(shares[j].getNumber()).multiply(BigInteger.valueOf(-1));
tmp = BigInteger.valueOf(shares[m].getNum()).add(tmp).mod(primeNum); tmp = BigInteger.valueOf(shares[m].getNumber()).add(tmp).mod(primeNum);
den = den.multiply(tmp).mod(primeNum); den = den.multiply(tmp).mod(primeNum);
} }
} }
final BigInteger value = shares[j].getShare(); final BigInteger value = shares[j].getShares();
tmp = den.modInverse(primeNum); tmp = den.modInverse(primeNum);
tmp = tmp.multiply(num).mod(primeNum); tmp = tmp.multiply(num).mod(primeNum);
tmp = tmp.multiply(value).mod(primeNum); tmp = tmp.multiply(value).mod(primeNum);
accum = accum.add(tmp).mod(primeNum); accumulate = accumulate.add(tmp).mod(primeNum);
LOGGER.info(Common.addTag("value: " + value + ", tmp: " + tmp + ", accum: " + accum));
} }
LOGGER.info(Common.addTag("The secret is: " + accum)); return accumulate;
return accum;
} }
private BigInteger randomZp(final BigInteger p) { private BigInteger randomZp(final BigInteger num) {
while (true) { while (true) {
final BigInteger r = new BigInteger(p.bitLength(), random); final BigInteger rand = new BigInteger(num.bitLength(), random);
if (r.compareTo(BigInteger.ZERO) > 0 && r.compareTo(p) < 0) { if (rand.compareTo(BigInteger.ZERO) > 0 && rand.compareTo(num) < 0) {
return r; return rand;
} }
} }
} }
private BigInteger prime; /**
private final int k; * Define the structure for store secret fragments.
private final int n; */
private final Random random; public final class SecretShares {
private final int SECRET_MAX_LEN = 32; private final int number;
private final BigInteger share;
public SecretShares(final int number, final BigInteger share) {
this.number = number;
this.share = share;
}
public int getNumber() {
return number;
}
public BigInteger getShares() {
return share;
}
@Override
public String toString() {
return "SecretShares [number=" + number + ", share=" + share + "]";
}
}
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,33 +16,114 @@
package com.mindspore.flclient.cipher.struct; package com.mindspore.flclient.cipher.struct;
import com.mindspore.flclient.Common;
import java.util.logging.Logger;
/**
* public key class of secure aggregation
*
* @since 2021-8-27
*/
public class ClientPublicKey { public class ClientPublicKey {
private static final Logger LOGGER = Logger.getLogger(ClientPublicKey.class.toString());
private String flID; private String flID;
private NewArray<byte[]> cPK; private NewArray<byte[]> cPK;
private NewArray<byte[]> sPk; private NewArray<byte[]> sPk;
private NewArray<byte[]> pwIv;
private NewArray<byte[]> pwSalt;
/**
* get client's flID
*
* @return flID of this client
*/
public String getFlID() { public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of <flID> is null, please set it before use"));
throw new IllegalArgumentException();
}
return flID; return flID;
} }
/**
* set client's flID
*
* @param flID hash value used for identify client
*/
public void setFlID(String flID) { public void setFlID(String flID) {
this.flID = flID; this.flID = flID;
} }
/**
* get CPK of secure aggregation
*
* @return CPK of secure aggregation
*/
public NewArray<byte[]> getCPK() { public NewArray<byte[]> getCPK() {
return cPK; return cPK;
} }
/**
* set CPK of secure aggregation
*
* @param cPK public key used for encryption
*/
public void setCPK(NewArray<byte[]> cPK) { public void setCPK(NewArray<byte[]> cPK) {
this.cPK = cPK; this.cPK = cPK;
} }
/**
* get SPK of secure aggregation
*
* @return SPK of secure aggregation
*/
public NewArray<byte[]> getSPK() { public NewArray<byte[]> getSPK() {
return sPk; return sPk;
} }
/**
* set SPK of secure aggregation
*
* @param sPk public key used for encryption
*/
public void setSPK(NewArray<byte[]> sPk) { public void setSPK(NewArray<byte[]> sPk) {
this.sPk = sPk; this.sPk = sPk;
} }
/**
* get the IV value used for pairwise encrypt
*
* @return the IV value used for pairwise encrypt
*/
public NewArray<byte[]> getPwIv() {
return pwIv;
}
/**
* set the IV value used for pairwise encrypt
*
* @param pwIv IV value used for pairwise encrypt
*/
public void setPwIv(NewArray<byte[]> pwIv) {
this.pwIv = pwIv;
}
/**
* get salt value for secure aggregation
*
* @return salt value for secure aggregation
*/
public NewArray<byte[]> getPwSalt() {
return pwSalt;
}
/**
* set salt value for secure aggregation
*
* @param pwSalt salt value for secure aggregation
*/
public void setPwSalt(NewArray<byte[]> pwSalt) {
this.pwSalt = pwSalt;
}
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,49 +16,114 @@
package com.mindspore.flclient.cipher.struct; package com.mindspore.flclient.cipher.struct;
import com.mindspore.flclient.Common;
import java.util.logging.Logger;
/**
* class used for set and get decryption shards
*
* @since 2021-8-27
*/
public class DecryptShareSecrets { public class DecryptShareSecrets {
private static final Logger LOGGER = Logger.getLogger(DecryptShareSecrets.class.toString());
private String flID; private String flID;
private NewArray<byte[]> sSkVu; private NewArray<byte[]> sSkVu;
private NewArray<byte[]> bVu; private NewArray<byte[]> bVu;
private int sIndex; private int sIndex;
private int indexB; private int indexB;
/**
* get flID of client
*
* @return flID of this client
*/
public String getFlID() { public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
"use"));
throw new IllegalArgumentException();
}
return flID; return flID;
} }
/**
* set flID for this client
*
* @param flID hash value used for identify client
*/
public void setFlID(String flID) { public void setFlID(String flID) {
this.flID = flID; this.flID = flID;
} }
/**
* get secret key shards
*
* @return secret key shards
*/
public NewArray<byte[]> getSSkVu() { public NewArray<byte[]> getSSkVu() {
return sSkVu; return sSkVu;
} }
/**
* set secret key shards
*
* @param sSkVu secret key shards
*/
public void setSSkVu(NewArray<byte[]> sSkVu) { public void setSSkVu(NewArray<byte[]> sSkVu) {
this.sSkVu = sSkVu; this.sSkVu = sSkVu;
} }
/**
* get bu shards
*
* @return bu shards
*/
public NewArray<byte[]> getBVu() { public NewArray<byte[]> getBVu() {
return bVu; return bVu;
} }
/**
* set bu shards
*
* @param bVu bu shards used for secure aggregation
*/
public void setBVu(NewArray<byte[]> bVu) { public void setBVu(NewArray<byte[]> bVu) {
this.bVu = bVu; this.bVu = bVu;
} }
/**
* get index of secret shards
*
* @return index of secret shards
*/
public int getSIndex() { public int getSIndex() {
return sIndex; return sIndex;
} }
/**
* set index of secret shards
*
* @param sIndex index of secret shards
*/
public void setSIndex(int sIndex) { public void setSIndex(int sIndex) {
this.sIndex = sIndex; this.sIndex = sIndex;
} }
/**
* get index of bu shards
*
* @return index of bu shards
*/
public int getIndexB() { public int getIndexB() {
return indexB; return indexB;
} }
/**
* set index of bu shards
*
* @param indexB index of bu shards
*/
public void setIndexB(int indexB) { public void setIndexB(int indexB) {
this.indexB = indexB; this.indexB = indexB;
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,22 +16,57 @@
package com.mindspore.flclient.cipher.struct; package com.mindspore.flclient.cipher.struct;
import com.mindspore.flclient.Common;
import java.util.logging.Logger;
/**
* class used for encrypt shares of secret
*
* @since 2021-8-27
*/
public class EncryptShare { public class EncryptShare {
private static final Logger LOGGER = Logger.getLogger(DecryptShareSecrets.class.toString());
private String flID; private String flID;
private NewArray<byte[]> share; private NewArray<byte[]> share;
/**
* get client's flID
*
* @return flID of this client
*/
public String getFlID() { public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
"use"));
throw new IllegalArgumentException();
}
return flID; return flID;
} }
/**
* set client's flID
*
* @param flID hash value used for identify client
*/
public void setFlID(String flID) { public void setFlID(String flID) {
this.flID = flID; this.flID = flID;
} }
/**
* get secret share
*
* @return secret share
*/
public NewArray<byte[]> getShare() { public NewArray<byte[]> getShare() {
return share; return share;
} }
/**
* set secret share
*
* @param share secret share
*/
public void setShare(NewArray<byte[]> share) { public void setShare(NewArray<byte[]> share) {
this.share = share; this.share = share;
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,24 +16,50 @@
package com.mindspore.flclient.cipher.struct; package com.mindspore.flclient.cipher.struct;
/**
* class used define new array type
*
* @param <T> an array
*
* @since 2021-8-27
*/
public class NewArray<T> { public class NewArray<T> {
private int size; private int size;
private T array; private T array;
/**
* get array size
*
* @return array size
*/
public int getSize() { public int getSize() {
return size; return size;
} }
/**
* set array size
*
* @param size array size
*/
public void setSize(int size) { public void setSize(int size) {
this.size = size; this.size = size;
} }
/**
* get array
*
* @return an array
*/
public T getArray() { public T getArray() {
return array; return array;
} }
/**
* set array
*
* @param array input
*/
public void setArray(T array) { public void setArray(T array) {
this.array = array; this.array = array;
} }
} }

View File

@ -1,5 +1,5 @@
/** /*
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -16,31 +16,75 @@
package com.mindspore.flclient.cipher.struct; package com.mindspore.flclient.cipher.struct;
import com.mindspore.flclient.Common;
import java.util.logging.Logger;
/**
* share secret class
*
* @since 2021-8-27
*/
public class ShareSecret { public class ShareSecret {
private static final Logger LOGGER = Logger.getLogger(ShareSecret.class.toString());
private String flID; private String flID;
private NewArray<byte[]> share; private NewArray<byte[]> share;
private int index; private int index;
/**
* get client's flID
*
* @return flID of this client
*/
public String getFlID() { public String getFlID() {
if (flID == null || flID.isEmpty()) {
LOGGER.severe(Common.addTag("[ShareSecret] the parameter of <flID> is null, please set it before use"));
throw new IllegalArgumentException();
}
return flID; return flID;
} }
/**
* set flID for this client
*
* @param flID hash value used for identify client
*/
public void setFlID(String flID) { public void setFlID(String flID) {
this.flID = flID; this.flID = flID;
} }
/**
* get secret share
*
* @return secret share
*/
public NewArray<byte[]> getShare() { public NewArray<byte[]> getShare() {
return share; return share;
} }
/**
* set secret share
*
* @param share secret shares
*/
public void setShare(NewArray<byte[]> share) { public void setShare(NewArray<byte[]> share) {
this.share = share; this.share = share;
} }
/**
* get secret index
*
* @return secret index
*/
public int getIndex() { public int getIndex() {
return index; return index;
} }
/**
* set secret index
*
* @param index secret index
*/
public void setIndex(int index) { public void setIndex(int index) {
this.index = index; this.index = index;
} }

View File

@ -31,6 +31,8 @@ table ClientPublicKeys {
fl_id:string; fl_id:string;
c_pk:[ubyte]; c_pk:[ubyte];
s_pk: [ubyte]; s_pk: [ubyte];
pw_iv: [ubyte];
pw_salt: [ubyte];
} }
table ClientShare { table ClientShare {
@ -45,6 +47,9 @@ table RequestExchangeKeys{
s_pk:[ubyte]; s_pk:[ubyte];
iteration:int; iteration:int;
timestamp:string; timestamp:string;
ind_iv:[ubyte];
pw_iv:[ubyte];
pw_salt:[ubyte];
} }
table ResponseExchangeKeys{ table ResponseExchangeKeys{