forked from mindspore-Ecosystem/mindspore
Fix security problems and code-check problems for federated's secure aggregation
fix security check problems for flclient
This commit is contained in:
parent
0abff9ad65
commit
7d9dd343f3
|
@ -39,7 +39,7 @@ if(NOT ENABLE_CPU OR WIN32)
|
|||
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/key_agreement.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/random.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/masking.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/secret_sharing.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_init.cc")
|
||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_keys.cc")
|
||||
|
|
|
@ -22,26 +22,28 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
||||
bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_initial_client_cnt,
|
||||
size_t cipher_exchange_secrets_cnt, size_t cipher_share_secrets_cnt,
|
||||
bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
|
||||
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
|
||||
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
|
||||
size_t cipher_reconstruct_secrets_up_cnt) {
|
||||
MS_LOG(INFO) << "CipherInit::Init START";
|
||||
int return_num = 0;
|
||||
return_num = memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN);
|
||||
if (return_num != 0) {
|
||||
if (memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN) != 0) {
|
||||
MS_LOG(ERROR) << "CipherInit::memory copy failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
publicparam_.g = param.g;
|
||||
publicparam_.t = param.t;
|
||||
secrets_minnums_ = param.t;
|
||||
client_num_need_ = cipher_initial_client_cnt;
|
||||
featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float);
|
||||
share_clients_num_need_ = cipher_share_secrets_cnt;
|
||||
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
|
||||
get_model_num_need_ = cipher_get_clientlist_cnt;
|
||||
|
||||
exchange_key_threshold = cipher_exchange_keys_cnt;
|
||||
get_key_threshold = cipher_get_keys_cnt;
|
||||
share_secrets_threshold = cipher_share_secrets_cnt;
|
||||
get_secrets_threshold = cipher_get_secrets_cnt;
|
||||
client_list_threshold = cipher_get_clientlist_cnt;
|
||||
reconstruct_secrets_threshold = cipher_reconstruct_secrets_up_cnt;
|
||||
|
||||
time_out_mutex_ = time_out_mutex;
|
||||
publicparam_.dp_eps = param.dp_eps;
|
||||
publicparam_.dp_delta = param.dp_delta;
|
||||
|
@ -62,10 +64,12 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
|||
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << " CipherInit client_num_need_ : " << client_num_need_;
|
||||
MS_LOG(INFO) << " CipherInit share_clients_num_need_ : " << share_clients_num_need_;
|
||||
MS_LOG(INFO) << " CipherInit reconstruct_clients_num_need_ : " << reconstruct_clients_num_need_;
|
||||
MS_LOG(INFO) << " CipherInit get_model_num_need_ : " << get_model_num_need_;
|
||||
MS_LOG(INFO) << " CipherInit exchange_key_threshold : " << exchange_key_threshold;
|
||||
MS_LOG(INFO) << " CipherInit get_key_threshold : " << get_key_threshold;
|
||||
MS_LOG(INFO) << " CipherInit share_secrets_threshold : " << share_secrets_threshold;
|
||||
MS_LOG(INFO) << " CipherInit get_secrets_threshold : " << get_secrets_threshold;
|
||||
MS_LOG(INFO) << " CipherInit client_list_threshold : " << client_list_threshold;
|
||||
MS_LOG(INFO) << " CipherInit reconstruct_secrets_threshold : " << reconstruct_secrets_threshold;
|
||||
MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_;
|
||||
if (!Check_Parames()) {
|
||||
MS_LOG(ERROR) << "Cipher parameters are illegal.";
|
||||
|
@ -82,11 +86,10 @@ bool CipherInit::Check_Parames() {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (share_clients_num_need_ < reconstruct_clients_num_need_) {
|
||||
MS_LOG(ERROR)
|
||||
<< "reconstruct_clients_num_need (which is reconstruct_secrets_threshold + 1) should not be larger "
|
||||
"than share_clients_num_need (which is start_fl_job_threshold*share_secrets_ratio), but got they are:"
|
||||
<< reconstruct_clients_num_need_ << ", " << share_clients_num_need_;
|
||||
if (share_secrets_threshold < reconstruct_secrets_threshold) {
|
||||
MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger "
|
||||
"than share_secrets_threshold, but got they are:"
|
||||
<< reconstruct_secrets_threshold << ", " << share_secrets_threshold;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -29,17 +29,6 @@
|
|||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
||||
template <typename T1>
|
||||
bool CreateArray(std::vector<T1> *newData, const flatbuffers::Vector<T1> &fbs_arr) {
|
||||
size_t size = newData->size();
|
||||
size_t size_fbs_arr = fbs_arr.size();
|
||||
if (size != size_fbs_arr) return false;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
newData->at(i) = fbs_arr.Get(i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Initialization of secure aggregation.
|
||||
class CipherInit {
|
||||
public:
|
||||
|
@ -49,9 +38,10 @@ class CipherInit {
|
|||
}
|
||||
|
||||
// Initialize the parameters of the secure aggregation.
|
||||
bool Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_initial_client_cnt,
|
||||
size_t cipher_exchange_secrets_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_clientlist_cnt,
|
||||
size_t cipher_reconstruct_secrets_down_cnt, size_t cipher_reconstruct_secrets_up_cnt);
|
||||
bool Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
|
||||
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
|
||||
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
|
||||
size_t cipher_reconstruct_secrets_up_cnt);
|
||||
|
||||
// Check whether the parameters are valid.
|
||||
bool Check_Parames();
|
||||
|
@ -59,10 +49,12 @@ class CipherInit {
|
|||
// Get public params. which is given to start fl job thread.
|
||||
CipherPublicPara *GetPublicParams() { return &publicparam_; }
|
||||
|
||||
size_t share_clients_num_need_; // the minimum number of clients to share secrets.
|
||||
size_t reconstruct_clients_num_need_; // the minimum number of clients to reconstruct secret mask.
|
||||
size_t client_num_need_; // the minimum number of clients to update model.
|
||||
size_t get_model_num_need_; // the minimum number of clients to get model.
|
||||
size_t share_secrets_threshold; // the minimum number of clients to share secret fragments.
|
||||
size_t get_secrets_threshold; // the minimum number of clients to get secret fragments.
|
||||
size_t reconstruct_secrets_threshold; // the minimum number of clients to reconstruct secret mask.
|
||||
size_t exchange_key_threshold; // the minimum number of clients to send public keys.
|
||||
size_t get_key_threshold; // the minimum number of clients to get public keys.
|
||||
size_t client_list_threshold; // the minimum number of clients to get update model client list.
|
||||
|
||||
size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask.
|
||||
size_t featuremap_; // the size of data to deal.
|
||||
|
|
|
@ -21,203 +21,170 @@ namespace mindspore {
|
|||
namespace armour {
|
||||
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
|
||||
const schema::GetExchangeKeys *get_exchange_keys_req,
|
||||
const std::shared_ptr<fl::server::FBBuilder> &get_exchange_keys_resp_builder) {
|
||||
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
|
||||
MS_LOG(INFO) << "CipherMgr::GetKeys START";
|
||||
if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) {
|
||||
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false);
|
||||
if (get_exchange_keys_req == nullptr) {
|
||||
MS_LOG(ERROR) << "Request is nullptr";
|
||||
BuildGetKeysRsp(fbb, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false);
|
||||
return false;
|
||||
}
|
||||
|
||||
// get clientlist from memory server.
|
||||
std::vector<std::string> clients;
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
|
||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &clients);
|
||||
|
||||
size_t cur_clients_num = clients.size();
|
||||
size_t cur_exchange_clients_num = client_public_keys.size();
|
||||
std::string fl_id = get_exchange_keys_req->fl_id()->str();
|
||||
|
||||
if (find(clients.begin(), clients.end(), fl_id) == clients.end()) {
|
||||
MS_LOG(INFO) << "The fl_id is not in clients.";
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
|
||||
if (cur_exchange_clients_num < cipher_init_->exchange_key_threshold) {
|
||||
MS_LOG(INFO) << "The server is not ready yet: cur_exchangekey_clients_num < exchange_key_threshold";
|
||||
MS_LOG(INFO) << "cur_exchangekey_clients_num : " << cur_exchange_clients_num
|
||||
<< ", exchange_key_threshold : " << cipher_init_->exchange_key_threshold;
|
||||
BuildGetKeysRsp(fbb, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
|
||||
return false;
|
||||
}
|
||||
if (cur_clients_num < cipher_init_->client_num_need_) {
|
||||
MS_LOG(INFO) << "The server is not ready yet: cur_clients_num < client_num_need";
|
||||
MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << ", client_num_need : " << cipher_init_->client_num_need_;
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
|
||||
|
||||
if (client_public_keys.find(fl_id) == client_public_keys.end()) {
|
||||
MS_LOG(INFO) << "Get keys: the fl_id: " << fl_id << "is not in exchange keys clients.";
|
||||
BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ret = cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetKeysClientList, fl_id);
|
||||
if (!ret) {
|
||||
MS_LOG(ERROR) << "update get keys clients failed";
|
||||
BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, cur_iterator, next_req_time, false);
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "GetKeys client list: ";
|
||||
for (size_t i = 0; i < clients.size(); i++) {
|
||||
MS_LOG(INFO) << "fl_id: " << clients[i];
|
||||
}
|
||||
|
||||
bool flag =
|
||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
|
||||
return flag;
|
||||
} // namespace armour
|
||||
BuildGetKeysRsp(fbb, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
||||
const schema::RequestExchangeKeys *exchange_keys_req,
|
||||
const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder) {
|
||||
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
|
||||
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
|
||||
// step 0: judge if the input param is legal.
|
||||
if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) {
|
||||
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
|
||||
std::string reason = "Request is nullptr or Response builder is nullptr.";
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time,
|
||||
cur_iterator);
|
||||
if (exchange_keys_req == nullptr) {
|
||||
std::string reason = "Request is nullptr";
|
||||
MS_LOG(ERROR) << reason;
|
||||
BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason, next_req_time, cur_iterator);
|
||||
return false;
|
||||
}
|
||||
std::string fl_id = exchange_keys_req->fl_id()->str();
|
||||
mindspore::fl::PBMetadata device_metas =
|
||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxDeviceMetas);
|
||||
mindspore::fl::FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
|
||||
MS_LOG(INFO) << "exchange key for fl id " << fl_id;
|
||||
if (fl_id_to_meta.fl_id_to_meta().count(fl_id) == 0) {
|
||||
std::string reason = "devices_meta for " + fl_id + " is not set. Please retry later.";
|
||||
BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
// step 1: get clientlist and client keys from memory server.
|
||||
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
|
||||
std::vector<std::string> client_list;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
|
||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||
|
||||
// step2: process new item data. and update new item data to memory server.
|
||||
size_t cur_clients_num = client_list.size();
|
||||
size_t cur_clients_has_keys_num = record_public_keys.size();
|
||||
size_t cur_clients_has_keys_num = client_public_keys.size();
|
||||
if (cur_clients_num != cur_clients_has_keys_num) {
|
||||
std::string reason = "client num and keys num are not equal.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
MS_LOG(ERROR) << "cur_clients_num is " << cur_clients_num << ". cur_clients_has_keys_num is "
|
||||
<< cur_clients_has_keys_num;
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, reason, next_req_time,
|
||||
cur_iterator);
|
||||
|
||||
return false;
|
||||
MS_LOG(WARNING) << reason;
|
||||
MS_LOG(WARNING) << "cur_clients_num is " << cur_clients_num << ". cur_clients_has_keys_num is "
|
||||
<< cur_clients_has_keys_num;
|
||||
}
|
||||
MS_LOG(WARNING) << "exchange_key_threshold " << cipher_init_->exchange_key_threshold << ". cur_clients_num "
|
||||
<< cur_clients_num << ". cur_clients_keys_num " << cur_clients_has_keys_num;
|
||||
|
||||
MS_LOG(INFO) << "client_num_need_ " << cipher_init_->client_num_need_ << ". cur_clients_num " << cur_clients_num;
|
||||
std::string fl_id = exchange_keys_req->fl_id()->str();
|
||||
if (cur_clients_num >= cipher_init_->client_num_need_) { // the client num is enough, return false.
|
||||
MS_LOG(ERROR) << "The server has received enough requests and refuse this request.";
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime,
|
||||
"The server has received enough requests and refuse this request.", next_req_time,
|
||||
cur_iterator);
|
||||
return false;
|
||||
}
|
||||
if (record_public_keys.find(fl_id) != record_public_keys.end()) { // the client already exists, return false.
|
||||
MS_LOG(INFO) << "The server has received the request, please do not request again.";
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
|
||||
if (client_public_keys.find(fl_id) != client_public_keys.end()) { // the client already exists, return false.
|
||||
MS_LOG(ERROR) << "The server has received the request, please do not request again.";
|
||||
BuildExchangeKeysRsp(fbb, schema::ResponseCode_SUCCEED,
|
||||
"The server has received the request, please do not request again.", next_req_time,
|
||||
cur_iterator);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gets the members of the deserialized data : exchange_keys_req
|
||||
auto fbs_cpk = exchange_keys_req->c_pk();
|
||||
size_t cpk_len = fbs_cpk->size();
|
||||
auto fbs_spk = exchange_keys_req->s_pk();
|
||||
size_t spk_len = fbs_spk->size();
|
||||
|
||||
// transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
|
||||
std::vector<std::vector<unsigned char>> cur_public_key;
|
||||
std::vector<unsigned char> cpk(cpk_len);
|
||||
std::vector<unsigned char> spk(spk_len);
|
||||
bool ret_create_code_cpk = CreateArray<unsigned char>(&cpk, *fbs_cpk);
|
||||
bool ret_create_code_spk = CreateArray<unsigned char>(&spk, *fbs_spk);
|
||||
if (!(ret_create_code_cpk && ret_create_code_spk)) {
|
||||
MS_LOG(ERROR) << "create cur_public_key failed";
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, "update key or client failed",
|
||||
next_req_time, cur_iterator);
|
||||
return false;
|
||||
}
|
||||
cur_public_key.push_back(cpk);
|
||||
cur_public_key.push_back(spk);
|
||||
|
||||
bool retcode_key =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, fl_id, cur_public_key);
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
|
||||
bool retcode_client =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
|
||||
if (retcode_key && retcode_client) {
|
||||
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
|
||||
"Success, but the server is not ready yet.", next_req_time, cur_iterator);
|
||||
BuildExchangeKeysRsp(fbb, schema::ResponseCode_SUCCEED, "Success, but the server is not ready yet.", next_req_time,
|
||||
cur_iterator);
|
||||
return true;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "update key or client failed";
|
||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, "update key or client failed",
|
||||
next_req_time, cur_iterator);
|
||||
BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "update key or client failed", next_req_time,
|
||||
cur_iterator);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void CipherKeys::BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder,
|
||||
const schema::ResponseCode retcode, const std::string &reason,
|
||||
const std::string &next_req_time, const int iteration) {
|
||||
auto rsp_reason = exchange_keys_resp_builder->CreateString(reason);
|
||||
auto rsp_next_req_time = exchange_keys_resp_builder->CreateString(next_req_time);
|
||||
schema::ResponseExchangeKeysBuilder rsp_builder(*(exchange_keys_resp_builder.get()));
|
||||
void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const std::string &next_req_time,
|
||||
const int iteration) {
|
||||
auto rsp_reason = fbb->CreateString(reason);
|
||||
auto rsp_next_req_time = fbb->CreateString(next_req_time);
|
||||
|
||||
schema::ResponseExchangeKeysBuilder rsp_builder(*(fbb.get()));
|
||||
rsp_builder.add_retcode(retcode);
|
||||
rsp_builder.add_reason(rsp_reason);
|
||||
rsp_builder.add_next_req_time(rsp_next_req_time);
|
||||
rsp_builder.add_iteration(iteration);
|
||||
auto rsp_exchange_keys = rsp_builder.Finish();
|
||||
exchange_keys_resp_builder->Finish(rsp_exchange_keys);
|
||||
fbb->Finish(rsp_exchange_keys);
|
||||
return;
|
||||
}
|
||||
|
||||
bool CipherKeys::BuildGetKeys(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const int iteration, const std::string &next_req_time, bool is_good) {
|
||||
bool flag = true;
|
||||
if (is_good) {
|
||||
// convert client keys to standard keys list.
|
||||
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
|
||||
MS_LOG(INFO) << "Get Keys: ";
|
||||
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
||||
if (record_public_keys.size() < cipher_init_->client_num_need_) {
|
||||
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
|
||||
<< "clients num: " << cipher_init_->client_num_need_;
|
||||
flag = false;
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||
rsp_buider.add_retcode(retcode);
|
||||
rsp_buider.add_iteration(iteration);
|
||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||
auto rsp_get_keys = rsp_buider.Finish();
|
||||
|
||||
fbb->Finish(rsp_get_keys);
|
||||
} else {
|
||||
for (auto iter = record_public_keys.begin(); iter != record_public_keys.end(); ++iter) {
|
||||
// read (fl_id, c_pk, s_pk) from the map: record_public_keys_
|
||||
std::string fl_id = iter->first;
|
||||
MS_LOG(INFO) << "fl_id : " << fl_id;
|
||||
|
||||
// To serialize the members to a new Table:ClientPublicKeys
|
||||
auto fbs_fl_id = fbb->CreateString(fl_id);
|
||||
auto fbs_c_pk = fbb->CreateVector(iter->second[0].data(), iter->second[0].size());
|
||||
auto fbs_s_pk = fbb->CreateVector(iter->second[1].data(), iter->second[1].size());
|
||||
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk);
|
||||
public_keys_list.push_back(cur_public_key);
|
||||
}
|
||||
auto remote_publickeys = fbb->CreateVector(public_keys_list);
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||
rsp_buider.add_retcode(retcode);
|
||||
rsp_buider.add_iteration(iteration);
|
||||
rsp_buider.add_remote_publickeys(remote_publickeys);
|
||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||
auto rsp_get_keys = rsp_buider.Finish();
|
||||
fbb->Finish(rsp_get_keys);
|
||||
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
|
||||
}
|
||||
} else {
|
||||
void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||
const int iteration, const std::string &next_req_time, bool is_good) {
|
||||
if (!is_good) {
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||
rsp_buider.add_retcode(retcode);
|
||||
rsp_buider.add_iteration(iteration);
|
||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||
auto rsp_get_keys = rsp_buider.Finish();
|
||||
|
||||
fbb->Finish(rsp_get_keys);
|
||||
return;
|
||||
}
|
||||
return flag;
|
||||
const fl::PBMetadata &clients_keys_pb_out =
|
||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientsKeys);
|
||||
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
||||
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
|
||||
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
|
||||
std::string fl_id = iter->first;
|
||||
fl::KeysPb keys_pb = iter->second;
|
||||
auto fbs_fl_id = fbb->CreateString(fl_id);
|
||||
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
||||
auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size());
|
||||
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
|
||||
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
|
||||
auto fbs_pw_iv = fbb->CreateVector(pw_iv.data(), pw_iv.size());
|
||||
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
|
||||
auto fbs_pw_salt = fbb->CreateVector(pw_salt.data(), pw_salt.size());
|
||||
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
|
||||
public_keys_list.push_back(cur_public_key);
|
||||
}
|
||||
auto remote_publickeys = fbb->CreateVector(public_keys_list);
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||
rsp_buider.add_retcode(retcode);
|
||||
rsp_buider.add_iteration(iteration);
|
||||
rsp_buider.add_remote_publickeys(remote_publickeys);
|
||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||
auto rsp_get_keys = rsp_buider.Finish();
|
||||
fbb->Finish(rsp_get_keys);
|
||||
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
|
||||
return;
|
||||
}
|
||||
|
||||
void CipherKeys::ClearKeys() {
|
||||
|
|
|
@ -44,21 +44,19 @@ class CipherKeys {
|
|||
|
||||
// handle the client's request of get keys.
|
||||
bool GetKeys(const int cur_iterator, const std::string &next_req_time,
|
||||
const schema::GetExchangeKeys *get_exchange_keys_req,
|
||||
const std::shared_ptr<fl::server::FBBuilder> &get_exchange_keys_resp_builder);
|
||||
const schema::GetExchangeKeys *get_exchange_keys_req, const std::shared_ptr<fl::server::FBBuilder> &fbb);
|
||||
|
||||
// handle the client's request of exchange keys.
|
||||
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
||||
const schema::RequestExchangeKeys *exchange_keys_req,
|
||||
const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder);
|
||||
const std::shared_ptr<fl::server::FBBuilder> &fbb);
|
||||
|
||||
// build response code of get keys.
|
||||
bool BuildGetKeys(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||
const int iteration, const std::string &next_req_time, bool is_good);
|
||||
void BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||
const int iteration, const std::string &next_req_time, bool is_good);
|
||||
// build response code of exchange keys.
|
||||
void BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder,
|
||||
const schema::ResponseCode retcode, const std::string &reason,
|
||||
const std::string &next_req_time, const int iteration);
|
||||
void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||
const std::string &reason, const std::string &next_req_time, const int iteration);
|
||||
// clear the shared memory.
|
||||
void ClearKeys();
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect
|
|||
}
|
||||
|
||||
void CipherMetaStorage::GetClientKeysFromServer(
|
||||
const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) {
|
||||
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) {
|
||||
const fl::PBMetadata &clients_keys_pb_out =
|
||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
||||
|
@ -60,12 +60,33 @@ void CipherMetaStorage::GetClientKeysFromServer(
|
|||
// const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
|
||||
std::string fl_id = iter->first;
|
||||
fl::KeysPb keys_pb = iter->second;
|
||||
std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||
std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
||||
std::vector<std::vector<unsigned char>> cur_keys;
|
||||
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
||||
std::vector<std::vector<uint8_t>> cur_keys;
|
||||
cur_keys.push_back(cpk);
|
||||
cur_keys.push_back(spk);
|
||||
clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<unsigned char>>>(fl_id, cur_keys));
|
||||
clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
|
||||
}
|
||||
}
|
||||
|
||||
void CipherMetaStorage::GetClientIVsFromServer(
|
||||
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) {
|
||||
const fl::PBMetadata &clients_keys_pb_out =
|
||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
||||
|
||||
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
|
||||
std::string fl_id = iter->first;
|
||||
fl::KeysPb keys_pb = iter->second;
|
||||
std::vector<uint8_t> ind_iv(keys_pb.ind_iv().begin(), keys_pb.ind_iv().end());
|
||||
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
|
||||
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
|
||||
|
||||
std::vector<std::vector<uint8_t>> cur_ivs;
|
||||
cur_ivs.push_back(ind_iv);
|
||||
cur_ivs.push_back(pw_iv);
|
||||
cur_ivs.push_back(pw_salt);
|
||||
clients_ivs_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -73,9 +94,18 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
|
|||
const fl::PBMetadata &clients_noises_pb_out =
|
||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
||||
int count = 0;
|
||||
int count_thld = 100;
|
||||
while (clients_noises_pb.has_one_client_noises() == false) {
|
||||
MS_LOG(INFO) << "GetClientNoisesFromServer NULL.";
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
||||
int register_time = 500;
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(register_time));
|
||||
count++;
|
||||
if (count >= count_thld) break;
|
||||
}
|
||||
MS_LOG(INFO) << "GetClientNoisesFromServer Count: " << count;
|
||||
if (clients_noises_pb.has_one_client_noises() == false) {
|
||||
MS_LOG(ERROR) << "GetClientNoisesFromServer NULL.";
|
||||
return false;
|
||||
}
|
||||
cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(),
|
||||
clients_noises_pb.one_client_noises().noise().end());
|
||||
|
@ -98,14 +128,14 @@ bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char
|
|||
}
|
||||
|
||||
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
|
||||
bool retcode = true;
|
||||
fl::FLId fl_id_pb;
|
||||
fl_id_pb.set_fl_id(fl_id);
|
||||
fl::PBMetadata client_pb;
|
||||
client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
|
||||
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
|
||||
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
|
||||
return retcode;
|
||||
}
|
||||
|
||||
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
|
||||
MS_LOG(INFO) << "register prime: " << prime;
|
||||
fl::Prime prime_id_pb;
|
||||
|
@ -117,9 +147,9 @@ void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &
|
|||
}
|
||||
|
||||
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
||||
const std::vector<std::vector<unsigned char>> &cur_public_key) {
|
||||
bool retcode = true;
|
||||
if (cur_public_key.size() < 2) {
|
||||
const std::vector<std::vector<uint8_t>> &cur_public_key) {
|
||||
size_t correct_size = 2;
|
||||
if (cur_public_key.size() < correct_size) {
|
||||
MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
|
||||
return false;
|
||||
}
|
||||
|
@ -132,7 +162,73 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std
|
|||
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
|
||||
fl::PBMetadata client_and_keys_pb;
|
||||
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
|
||||
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
||||
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
||||
return retcode;
|
||||
}
|
||||
|
||||
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
|
||||
const schema::RequestExchangeKeys *exchange_keys_req) {
|
||||
std::string fl_id = exchange_keys_req->fl_id()->str();
|
||||
auto fbs_cpk = exchange_keys_req->c_pk();
|
||||
auto fbs_spk = exchange_keys_req->s_pk();
|
||||
if (fbs_cpk == nullptr || fbs_spk == nullptr) {
|
||||
MS_LOG(ERROR) << "public key from exchange_keys_req is null";
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t spk_len = fbs_spk->size();
|
||||
size_t cpk_len = fbs_cpk->size();
|
||||
|
||||
// transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
|
||||
std::vector<std::vector<uint8_t>> cur_public_key;
|
||||
std::vector<uint8_t> cpk(cpk_len);
|
||||
std::vector<uint8_t> spk(spk_len);
|
||||
bool ret_create_code_cpk = CreateArray<uint8_t>(&cpk, *fbs_cpk);
|
||||
bool ret_create_code_spk = CreateArray<uint8_t>(&spk, *fbs_spk);
|
||||
if (!(ret_create_code_cpk && ret_create_code_spk)) {
|
||||
MS_LOG(ERROR) << "create array for public keys failed";
|
||||
return false;
|
||||
}
|
||||
cur_public_key.push_back(cpk);
|
||||
cur_public_key.push_back(spk);
|
||||
|
||||
auto fbs_ind_iv = exchange_keys_req->ind_iv();
|
||||
std::vector<char> ind_iv;
|
||||
if (fbs_ind_iv == nullptr) {
|
||||
MS_LOG(WARNING) << "ind_iv in exchange_keys_req is nullptr";
|
||||
} else {
|
||||
ind_iv.assign(fbs_ind_iv->begin(), fbs_ind_iv->end());
|
||||
}
|
||||
|
||||
auto fbs_pw_iv = exchange_keys_req->pw_iv();
|
||||
std::vector<char> pw_iv;
|
||||
if (fbs_pw_iv == nullptr) {
|
||||
MS_LOG(WARNING) << "pw_iv in exchange_keys_req is nullptr";
|
||||
} else {
|
||||
pw_iv.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
|
||||
}
|
||||
|
||||
auto fbs_pw_salt = exchange_keys_req->pw_salt();
|
||||
std::vector<char> pw_salt;
|
||||
if (fbs_pw_salt == nullptr) {
|
||||
MS_LOG(WARNING) << "pw_salt in exchange_keys_req is nullptr";
|
||||
} else {
|
||||
pw_salt.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
|
||||
}
|
||||
|
||||
// update new item to memory server.
|
||||
fl::KeysPb keys;
|
||||
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
|
||||
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
|
||||
keys.set_ind_iv(ind_iv.data(), ind_iv.size());
|
||||
keys.set_pw_iv(pw_iv.data(), pw_iv.size());
|
||||
keys.set_pw_salt(pw_salt.data(), pw_salt.size());
|
||||
fl::PairClientKeys pair_client_keys_pb;
|
||||
pair_client_keys_pb.set_fl_id(fl_id);
|
||||
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
|
||||
fl::PBMetadata client_and_keys_pb;
|
||||
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
|
||||
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
||||
return retcode;
|
||||
}
|
||||
|
||||
|
@ -142,13 +238,13 @@ bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const s
|
|||
*noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
|
||||
fl::PBMetadata client_noises_pb;
|
||||
client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
|
||||
return fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
|
||||
bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool CipherMetaStorage::UpdateClientShareToServer(
|
||||
const char *list_name, const std::string &fl_id,
|
||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
|
||||
bool retcode = true;
|
||||
int size_shares = shares->size();
|
||||
fl::SharesPb shares_pb;
|
||||
for (int index = 0; index < size_shares; ++index) {
|
||||
|
@ -166,14 +262,17 @@ bool CipherMetaStorage::UpdateClientShareToServer(
|
|||
pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
|
||||
fl::PBMetadata client_and_shares_pb;
|
||||
client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
|
||||
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
|
||||
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
|
||||
return retcode;
|
||||
}
|
||||
|
||||
void CipherMetaStorage::RegisterClass() {
|
||||
fl::PBMetadata exchange_kyes_client_list;
|
||||
fl::PBMetadata exchange_keys_client_list;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
|
||||
exchange_kyes_client_list);
|
||||
exchange_keys_client_list);
|
||||
fl::PBMetadata get_keys_client_list;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetKeysClientList,
|
||||
get_keys_client_list);
|
||||
fl::PBMetadata clients_keys;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
|
||||
fl::PBMetadata reconstruct_client_list;
|
||||
|
@ -185,9 +284,15 @@ void CipherMetaStorage::RegisterClass() {
|
|||
fl::PBMetadata share_secretes_client_list;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
|
||||
share_secretes_client_list);
|
||||
fl::PBMetadata get_secretes_client_list;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetSecretsClientList,
|
||||
get_secretes_client_list);
|
||||
fl::PBMetadata clients_encrypt_shares;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
|
||||
clients_encrypt_shares);
|
||||
fl::PBMetadata get_update_clients_list;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetUpdateModelClientList,
|
||||
get_update_clients_list);
|
||||
}
|
||||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,23 +31,39 @@
|
|||
#include "fl/server/distributed_metadata_store.h"
|
||||
#include "fl/server/common.h"
|
||||
|
||||
#define IND_IV_INDEX 0
|
||||
#define PW_IV_INDEX 1
|
||||
#define PW_SALT_INDEX 2
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
||||
template <typename T1>
|
||||
bool CreateArray(std::vector<T1> *newData, const flatbuffers::Vector<T1> &fbs_arr) {
|
||||
if (newData == nullptr) return false;
|
||||
size_t size = newData->size();
|
||||
size_t size_fbs_arr = fbs_arr.size();
|
||||
if (size != size_fbs_arr) return false;
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
newData->at(i) = fbs_arr.Get(i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
constexpr int SHARE_MAX_SIZE = 256;
|
||||
constexpr int SECRET_MAX_LEN_DOUBLE = 66;
|
||||
|
||||
struct clientshare_str {
|
||||
std::string fl_id;
|
||||
std::vector<unsigned char> share;
|
||||
std::vector<uint8_t> share;
|
||||
int index;
|
||||
};
|
||||
|
||||
struct CipherPublicPara {
|
||||
int t;
|
||||
int g;
|
||||
unsigned char prime[PRIME_MAX_LEN];
|
||||
unsigned char p[SECRET_MAX_LEN];
|
||||
uint8_t prime[PRIME_MAX_LEN];
|
||||
uint8_t p[SECRET_MAX_LEN];
|
||||
float dp_eps;
|
||||
float dp_delta;
|
||||
float dp_norm_clip;
|
||||
|
@ -62,7 +78,7 @@ class CipherMetaStorage {
|
|||
// Register Prime.
|
||||
void RegisterPrime(const char *list_name, const std::string &prime);
|
||||
// Get tprime from shared server.
|
||||
bool GetPrimeFromServer(const char *prime_name, unsigned char *prime);
|
||||
bool GetPrimeFromServer(const char *prime_name, uint8_t *prime);
|
||||
// Get client shares from shared server.
|
||||
void GetClientSharesFromServer(const char *list_name,
|
||||
std::map<std::string, std::vector<clientshare_str>> *clients_shares_list);
|
||||
|
@ -70,14 +86,18 @@ class CipherMetaStorage {
|
|||
void GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list);
|
||||
// Get client keys from shared server.
|
||||
void GetClientKeysFromServer(const char *list_name,
|
||||
std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list);
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list);
|
||||
void GetClientIVsFromServer(const char *list_name,
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list);
|
||||
// Get client noises from shared server.
|
||||
bool GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise);
|
||||
// Update client fl_id to shared server.
|
||||
bool UpdateClientToServer(const char *list_name, const std::string &fl_id);
|
||||
// Update client key to shared server.
|
||||
bool UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
||||
const std::vector<std::vector<unsigned char>> &cur_public_key);
|
||||
const std::vector<std::vector<uint8_t>> &cur_public_key);
|
||||
// Update client key with signature to shared server.
|
||||
bool UpdateClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req);
|
||||
// Update client noise to shared server.
|
||||
bool UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise);
|
||||
// Update client share to shared server.
|
||||
|
|
|
@ -16,33 +16,35 @@
|
|||
|
||||
#include "fl/armour/cipher/cipher_reconstruct.h"
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/armour/secure_protocol/random.h"
|
||||
#include "fl/armour/secure_protocol/masking.h"
|
||||
#include "fl/armour/secure_protocol/key_agreement.h"
|
||||
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
bool CipherReconStruct::CombineMask(
|
||||
std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_keys,
|
||||
const std::vector<std::string> &clients_share_list,
|
||||
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
|
||||
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
|
||||
const std::vector<string> &client_list) {
|
||||
bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp,
|
||||
std::map<std::string, std::vector<float>> *client_noise,
|
||||
const std::vector<std::string> &clients_share_list,
|
||||
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
|
||||
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
|
||||
const std::vector<string> &client_list,
|
||||
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs) {
|
||||
bool retcode = true;
|
||||
#ifdef _WIN32
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
retcode = false;
|
||||
#else
|
||||
if (shares_tmp == nullptr || client_noise == nullptr) {
|
||||
MS_LOG(ERROR) << "shares_tmp or client_noise is nullptr.";
|
||||
return false;
|
||||
}
|
||||
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
|
||||
// define flag_share: judge we need b or s
|
||||
bool flag_share = true;
|
||||
const std::string fl_id = iter->first;
|
||||
std::vector<std::string>::const_iterator ptr = client_list.begin();
|
||||
for (; ptr < client_list.end(); ++ptr) {
|
||||
if (*ptr == fl_id) {
|
||||
flag_share = false;
|
||||
break;
|
||||
}
|
||||
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) {
|
||||
// the client is online
|
||||
flag_share = false;
|
||||
}
|
||||
MS_LOG(INFO) << "fl_id_src : " << fl_id;
|
||||
BIGNUM *prime = BN_new();
|
||||
|
@ -61,7 +63,6 @@ bool CipherReconStruct::CombineMask(
|
|||
MS_LOG(ERROR) << "shares_tmp copy failed";
|
||||
retcode = false;
|
||||
}
|
||||
MS_LOG(INFO) << "fl_id_des : " << (iter->second)[i].fl_id;
|
||||
std::string print_share_data(reinterpret_cast<const char *>(shares_tmp->at(i)->data), shares_tmp->at(i)->len);
|
||||
}
|
||||
MS_LOG(INFO) << "end assign secrets shares to public shares ";
|
||||
|
@ -74,23 +75,42 @@ bool CipherReconStruct::CombineMask(
|
|||
MS_LOG(INFO) << "combine secrets shares Success.";
|
||||
|
||||
if (flag_share) {
|
||||
MS_LOG(INFO) << "start get complete s_uv.";
|
||||
// reconstruct pairwise noise
|
||||
MS_LOG(INFO) << "start reconstruct pairwise noise.";
|
||||
std::vector<float> noise(cipher_init_->featuremap_, 0.0);
|
||||
if (GetSuvNoise(clients_share_list, record_public_keys, fl_id, &noise, secret, length) == false)
|
||||
if (GetSuvNoise(clients_share_list, record_public_keys, client_ivs, fl_id, &noise, secret, length) == false)
|
||||
retcode = false;
|
||||
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
||||
client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
||||
MS_LOG(INFO) << " fl_id : " << fl_id;
|
||||
MS_LOG(INFO) << "end get complete s_uv.";
|
||||
} else {
|
||||
// reconstruct individual noise
|
||||
MS_LOG(INFO) << "start reconstruct individual noise.";
|
||||
std::vector<float> noise;
|
||||
if (Random::RandomAESCTR(&noise, cipher_init_->featuremap_, (const unsigned char *)secret, SECRET_MAX_LEN) < 0)
|
||||
auto it = client_ivs.find(fl_id);
|
||||
if (it == client_ivs.end()) {
|
||||
MS_LOG(ERROR) << "cannot get ivs for client: " << fl_id;
|
||||
return false;
|
||||
}
|
||||
if (it->second.size() != IV_NUM) {
|
||||
MS_LOG(ERROR) << "get " << it->second.size() << " ivs, the iv num required is: " << IV_NUM;
|
||||
return false;
|
||||
}
|
||||
std::vector<uint8_t> ind_iv = it->second[0];
|
||||
if (Masking::GetMasking(&noise, cipher_init_->featuremap_, (const uint8_t *)secret, SECRET_MAX_LEN,
|
||||
ind_iv.data(), ind_iv.size()) < 0)
|
||||
retcode = false;
|
||||
for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) {
|
||||
noise[index_noise] *= -1;
|
||||
}
|
||||
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
||||
MS_LOG(INFO) << " fl_id : " << fl_id;
|
||||
client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "reconstruct secret failed: the number of secret shares for fl_id: " << fl_id
|
||||
<< " is not enough";
|
||||
MS_LOG(ERROR) << "get " << iter->second.size()
|
||||
<< "shares, however the secrets_minnums_ required is: " << cipher_init_->secrets_minnums_;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -98,173 +118,188 @@ bool CipherReconStruct::CombineMask(
|
|||
}
|
||||
|
||||
bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &client_list) {
|
||||
// get reconstruct_secret_list_ori from memory server
|
||||
// get reconstruct_secrets from memory server
|
||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
|
||||
bool retcode = true;
|
||||
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori;
|
||||
std::map<std::string, std::vector<clientshare_str>> reconstruct_secrets;
|
||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
|
||||
&reconstruct_secret_list_ori);
|
||||
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
||||
&reconstruct_secrets);
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> record_public_keys;
|
||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
||||
std::vector<std::string> clients_reconstruct_list;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
|
||||
&clients_reconstruct_list);
|
||||
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> client_ivs;
|
||||
cipher_init_->cipher_meta_storage_.GetClientIVsFromServer(fl::server::kCtxClientsKeys, &client_ivs);
|
||||
|
||||
std::vector<std::string> clients_share_list;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
||||
&clients_share_list);
|
||||
if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() ||
|
||||
record_public_keys.size() < cipher_init_->client_num_need_ ||
|
||||
clients_share_list.size() < cipher_init_->share_clients_num_need_) {
|
||||
if (record_public_keys.size() < cipher_init_->exchange_key_threshold ||
|
||||
clients_share_list.size() < cipher_init_->share_secrets_threshold ||
|
||||
record_public_keys.size() != client_ivs.size()) {
|
||||
MS_LOG(ERROR) << "send share client size: " << clients_share_list.size()
|
||||
<< ", send public-key client size: " << record_public_keys.size()
|
||||
<< ", send ivs client size: " << client_ivs.size();
|
||||
MS_LOG(ERROR) << "get data from server memory failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list;
|
||||
ConvertSharesToShares(reconstruct_secret_list_ori, &reconstruct_secret_list);
|
||||
if (!ConvertSharesToShares(reconstruct_secrets, &reconstruct_secret_list)) {
|
||||
MS_LOG(ERROR) << "ConvertSharesToShares failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "recombined shares";
|
||||
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
|
||||
MS_LOG(ERROR) << "fl_id: " << iter->first;
|
||||
MS_LOG(ERROR) << "share size: " << iter->second.size();
|
||||
}
|
||||
std::vector<Share *> shares_tmp;
|
||||
if (MallocShares(&shares_tmp, cipher_init_->secrets_minnums_) == false) {
|
||||
if (!MallocShares(&shares_tmp, cipher_init_->secrets_minnums_)) {
|
||||
MS_LOG(ERROR) << "Reconstruct malloc shares_tmp invalid.";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Reconstruct client list: ";
|
||||
std::vector<std::string>::const_iterator ptr_tmp = client_list.begin();
|
||||
for (; ptr_tmp < client_list.end(); ++ptr_tmp) {
|
||||
MS_LOG(INFO) << *ptr_tmp;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Reconstruct secrets shares: ";
|
||||
std::map<std::string, std::vector<float>> client_keys;
|
||||
|
||||
retcode = CombineMask(&shares_tmp, &client_keys, clients_share_list, record_public_keys, reconstruct_secret_list,
|
||||
client_list);
|
||||
|
||||
std::map<std::string, std::vector<float>> client_noise;
|
||||
retcode = CombineMask(&shares_tmp, &client_noise, clients_share_list, record_public_keys, reconstruct_secret_list,
|
||||
client_list, client_ivs);
|
||||
DeleteShares(&shares_tmp);
|
||||
if (retcode) {
|
||||
std::vector<float> noise;
|
||||
if (GetNoiseMasksSum(&noise, client_keys) == false) {
|
||||
if (!GetNoiseMasksSum(&noise, client_noise)) {
|
||||
MS_LOG(ERROR) << " GetNoiseMasksSum failed";
|
||||
return false;
|
||||
}
|
||||
client_keys.clear();
|
||||
client_noise.clear();
|
||||
MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";
|
||||
|
||||
if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise) == false)
|
||||
if (!cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise)) {
|
||||
MS_LOG(ERROR) << " ReconstructSecretsGenNoise failed. because UpdateClientNoiseToServer failed";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
|
||||
} else {
|
||||
MS_LOG(INFO) << " ReconstructSecretsGenNoise failed. because gen noise inside failed";
|
||||
MS_LOG(ERROR) << " ReconstructSecretsGenNoise failed. because gen noise inside failed";
|
||||
}
|
||||
|
||||
return retcode;
|
||||
}
|
||||
|
||||
// reconstruct secrets
|
||||
bool CipherReconStruct::ReconstructSecrets(
|
||||
const int cur_iterator, const std::string &next_req_time, const schema::SendReconstructSecret *reconstruct_secret_req,
|
||||
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
|
||||
const std::vector<std::string> &client_list) {
|
||||
bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
|
||||
const schema::SendReconstructSecret *reconstruct_secret_req,
|
||||
const std::shared_ptr<fl::server::FBBuilder> &fbb,
|
||||
const std::vector<std::string> &client_list) {
|
||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
|
||||
clock_t start_time = clock();
|
||||
if (reconstruct_secret_req == nullptr || reconstruct_secret_resp_builder == nullptr) {
|
||||
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr. ";
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_RequestError,
|
||||
"Request is nullptr or Response builder is nullptr.", cur_iterator, next_req_time);
|
||||
return false;
|
||||
}
|
||||
if (client_list.size() < cipher_init_->reconstruct_clients_num_need_) {
|
||||
MS_LOG(ERROR) << "illegal parameters. update model client_list size: " << client_list.size();
|
||||
BuildReconstructSecretsRsp(
|
||||
reconstruct_secret_resp_builder, schema::ResponseCode_RequestError,
|
||||
"illegal parameters: update model client_list size must larger than reconstruct_clients_num_need", cur_iterator,
|
||||
next_req_time);
|
||||
return false;
|
||||
}
|
||||
std::vector<std::string> clients_reconstruct_list;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
|
||||
&clients_reconstruct_list);
|
||||
std::map<std::string, std::vector<clientshare_str>> clients_shares_all;
|
||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
|
||||
&clients_shares_all);
|
||||
|
||||
size_t count_client_num = clients_shares_all.size();
|
||||
if (count_client_num != clients_reconstruct_list.size()) {
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
||||
"shares client size and client size are not equal.", cur_iterator, next_req_time);
|
||||
MS_LOG(ERROR) << "shares client size and client size are not equal.";
|
||||
if (reconstruct_secret_req == nullptr) {
|
||||
std::string reason = "Request is nullptr";
|
||||
MS_LOG(ERROR) << reason;
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, cur_iterator, next_req_time);
|
||||
return false;
|
||||
}
|
||||
|
||||
int iterator = reconstruct_secret_req->iteration();
|
||||
std::string fl_id = reconstruct_secret_req->fl_id()->str();
|
||||
if (iterator != cur_iterator) {
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||
"The iteration round of the client does not match the current iteration.", cur_iterator,
|
||||
next_req_time);
|
||||
MS_LOG(ERROR) << "Client " << fl_id << " The iteration round of the client does not match the current iteration.";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (find(client_list.begin(), client_list.end(), fl_id) == client_list.end()) { // client not in client list.
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
||||
"The client is not in update model client list.", cur_iterator, next_req_time);
|
||||
MS_LOG(ERROR) << "The client " << fl_id << " is not in update model client list.";
|
||||
if (client_list.size() < cipher_init_->reconstruct_secrets_threshold) {
|
||||
MS_LOG(ERROR) << "illegal parameters. update model client_list size: " << client_list.size();
|
||||
BuildReconstructSecretsRsp(
|
||||
fbb, schema::ResponseCode_RequestError,
|
||||
"illegal parameters: update model client_list size must larger than reconstruct_clients_num_need", cur_iterator,
|
||||
next_req_time);
|
||||
return false;
|
||||
}
|
||||
if (find(clients_reconstruct_list.begin(), clients_reconstruct_list.end(), fl_id) != clients_reconstruct_list.end()) {
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
|
||||
"Client has sended messages.", cur_iterator, next_req_time);
|
||||
|
||||
std::vector<std::string> get_clients_list;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetUpdateModelClientList,
|
||||
&get_clients_list);
|
||||
// client not in get client list.
|
||||
if (find(get_clients_list.begin(), get_clients_list.end(), fl_id) == get_clients_list.end()) {
|
||||
std::string reason;
|
||||
MS_LOG(INFO) << "The client " << fl_id << " is not in get update model client list.";
|
||||
// client in update model client list.
|
||||
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) {
|
||||
reason = "The client " + fl_id + " is not in get clients list, but in update model client list.";
|
||||
MS_LOG(INFO) << reason;
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, reason, cur_iterator, next_req_time);
|
||||
return false;
|
||||
}
|
||||
reason = "The client " + fl_id + " is not in get clients list, and not in update model client list.";
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, "The client is not in update model client list.",
|
||||
cur_iterator, next_req_time);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
std::map<std::string, std::vector<clientshare_str>> reconstruct_shares;
|
||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
|
||||
&reconstruct_shares);
|
||||
size_t count_client_num = reconstruct_shares.size();
|
||||
if (reconstruct_shares.find(fl_id) != reconstruct_shares.end()) {
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Client has sended messages.", cur_iterator,
|
||||
next_req_time);
|
||||
MS_LOG(INFO) << "Error, client " << fl_id << " has sended messages.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
|
||||
bool retcode_client =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id);
|
||||
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
||||
fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
|
||||
if (!(retcode_client && retcode_share)) {
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
||||
"reconstruct update shares or client failed.", cur_iterator, next_req_time);
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "reconstruct update shares or client failed.",
|
||||
cur_iterator, next_req_time);
|
||||
MS_LOG(ERROR) << "reconstruct update shares or client failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
count_client_num = count_client_num + 1;
|
||||
if (count_client_num < cipher_init_->reconstruct_clients_num_need_) {
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
|
||||
"Success,but the server is not ready to reconstruct secret yet.", cur_iterator,
|
||||
if (count_client_num < cipher_init_->reconstruct_secrets_threshold) {
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED,
|
||||
"Success, but the server is not ready to reconstruct secret yet.", cur_iterator,
|
||||
next_req_time);
|
||||
MS_LOG(INFO) << "ReconstructSecrets" << fl_id << " Success, but count " << count_client_num << "is not enough.";
|
||||
MS_LOG(INFO) << "ReconstructSecrets " << fl_id << " Success, but count " << count_client_num << "is not enough.";
|
||||
return true;
|
||||
} else {
|
||||
bool retcode_result = true;
|
||||
const fl::PBMetadata &clients_noises_pb_out =
|
||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
|
||||
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
||||
if (clients_noises_pb.has_one_client_noises() == false) {
|
||||
MS_LOG(INFO) << "Success,the secret will be reconstructed.";
|
||||
retcode_result = ReconstructSecretsGenNoise(client_list);
|
||||
if (retcode_result) {
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
|
||||
"Success,the secret is reconstructing.", cur_iterator, next_req_time);
|
||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success, reconstruct ok.";
|
||||
} else {
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
||||
"the secret restructs failed.", cur_iterator, next_req_time);
|
||||
MS_LOG(ERROR) << "the secret restructs failed.";
|
||||
}
|
||||
} else {
|
||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
|
||||
"Clients' number is full.", cur_iterator, next_req_time);
|
||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success : no need reconstruct.";
|
||||
}
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "Reconstruct get + gennoise data time is : " << duration;
|
||||
return retcode_result;
|
||||
}
|
||||
const fl::PBMetadata &clients_noises_pb_out =
|
||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
|
||||
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
||||
if (clients_noises_pb.has_one_client_noises() == false) {
|
||||
MS_LOG(INFO) << "Success, the secret will be reconstructed.";
|
||||
if (ReconstructSecretsGenNoise(client_list)) {
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Success,the secret is reconstructing.",
|
||||
cur_iterator, next_req_time);
|
||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success, reconstruct ok.";
|
||||
} else {
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "the secret restructs failed.", cur_iterator,
|
||||
next_req_time);
|
||||
MS_LOG(ERROR) << "CipherReconStruct::ReconstructSecrets" << fl_id << " failed.";
|
||||
}
|
||||
} else {
|
||||
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Clients' number is full.", cur_iterator,
|
||||
next_req_time);
|
||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success : no need reconstruct.";
|
||||
}
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "Reconstruct get + gennoise data time is : " << duration;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,
|
||||
const std::map<std::string, std::vector<float>> &client_keys) {
|
||||
const std::map<std::string, std::vector<float>> &client_noise) {
|
||||
std::vector<float> sum(cipher_init_->featuremap_, 0.0);
|
||||
for (auto iter = client_keys.begin(); iter != client_keys.end(); iter++) {
|
||||
for (auto iter = client_noise.begin(); iter != client_noise.end(); iter++) {
|
||||
if (iter->second.size() != cipher_init_->featuremap_) {
|
||||
return false;
|
||||
}
|
||||
|
@ -301,35 +336,52 @@ void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::ser
|
|||
return;
|
||||
}
|
||||
|
||||
bool CipherReconStruct::GetSuvNoise(
|
||||
const std::vector<std::string> &clients_share_list,
|
||||
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys, const string &fl_id,
|
||||
std::vector<float> *noise, uint8_t *secret, int length) {
|
||||
bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_share_list,
|
||||
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
|
||||
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs,
|
||||
const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length) {
|
||||
for (auto p_key = clients_share_list.begin(); p_key != clients_share_list.end(); ++p_key) {
|
||||
if (*p_key != fl_id) {
|
||||
PrivateKey *privKey1 = KeyAgreement::FromPrivateBytes((unsigned char *)secret, length);
|
||||
if (privKey1 == NULL) {
|
||||
MS_LOG(ERROR) << "create privKey1 failed\n";
|
||||
PrivateKey *privKey = KeyAgreement::FromPrivateBytes(secret, length);
|
||||
if (privKey == NULL) {
|
||||
MS_LOG(ERROR) << "create privKey failed\n";
|
||||
return false;
|
||||
}
|
||||
std::vector<unsigned char> public_key = record_public_keys.at(*p_key)[1];
|
||||
PublicKey *pubKey1 = KeyAgreement::FromPublicBytes(public_key.data(), public_key.size());
|
||||
if (pubKey1 == NULL) {
|
||||
MS_LOG(ERROR) << "create pubKey1 failed\n";
|
||||
std::vector<uint8_t> public_key = record_public_keys.at(*p_key)[1];
|
||||
std::string iv_fl_id;
|
||||
if (fl_id < *p_key) {
|
||||
iv_fl_id = fl_id;
|
||||
} else {
|
||||
iv_fl_id = *p_key;
|
||||
}
|
||||
auto iter = client_ivs.find(iv_fl_id);
|
||||
if (iter == client_ivs.end()) {
|
||||
MS_LOG(ERROR) << "cannot get ivs for client: " << iv_fl_id;
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "fl_id : " << fl_id << "other id : " << *p_key;
|
||||
unsigned char secret1[SECRET_MAX_LEN] = {0};
|
||||
unsigned char salt[SECRET_MAX_LEN] = {0};
|
||||
if (KeyAgreement::ComputeSharedKey(privKey1, pubKey1, SECRET_MAX_LEN, salt, SECRET_MAX_LEN, secret1) < 0) {
|
||||
if (iter->second.size() != IV_NUM) {
|
||||
MS_LOG(ERROR) << "get " << iter->second.size() << " ivs, the iv num required is: " << IV_NUM;
|
||||
return false;
|
||||
}
|
||||
std::vector<uint8_t> pw_iv = iter->second[PW_IV_INDEX];
|
||||
std::vector<uint8_t> pw_salt = iter->second[PW_SALT_INDEX];
|
||||
PublicKey *pubKey = KeyAgreement::FromPublicBytes(public_key.data(), public_key.size());
|
||||
if (pubKey == NULL) {
|
||||
MS_LOG(ERROR) << "create pubKey failed\n";
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "private_key fl_id : " << fl_id << " public_key fl_id : " << *p_key;
|
||||
uint8_t secret1[SECRET_MAX_LEN] = {0};
|
||||
if (KeyAgreement::ComputeSharedKey(privKey, pubKey, SECRET_MAX_LEN, pw_salt.data(), pw_salt.size(), secret1) <
|
||||
0) {
|
||||
MS_LOG(ERROR) << "ComputeSharedKey failed\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<float> noise_tmp;
|
||||
if (Random::RandomAESCTR(&noise_tmp, cipher_init_->featuremap_, (const unsigned char *)secret1, SECRET_MAX_LEN) <
|
||||
0) {
|
||||
MS_LOG(ERROR) << "RandomAESCTR failed\n";
|
||||
if (Masking::GetMasking(&noise_tmp, cipher_init_->featuremap_, (const uint8_t *)secret1, SECRET_MAX_LEN,
|
||||
pw_iv.data(), pw_iv.size()) < 0) {
|
||||
MS_LOG(ERROR) << "Get Masking failed\n";
|
||||
return false;
|
||||
}
|
||||
bool symbol_noise = GetSymbol(fl_id, *p_key);
|
||||
|
@ -345,9 +397,6 @@ bool CipherReconStruct::GetSuvNoise(
|
|||
noise->at(index) += noise_tmp[index];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < 5; i++) {
|
||||
MS_LOG(INFO) << "index " << i << " : " << noise_tmp[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
@ -361,37 +410,41 @@ bool CipherReconStruct::GetSymbol(const std::string &str1, const std::string &st
|
|||
}
|
||||
}
|
||||
|
||||
void CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
|
||||
// recombined shares by their source fl_id (ownners)
|
||||
bool CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
|
||||
std::map<std::string, std::vector<clientshare_str>> *des) {
|
||||
for (auto iter_ori = src.begin(); iter_ori != src.end(); ++iter_ori) {
|
||||
std::string fl_des = iter_ori->first;
|
||||
auto &cur_clientshare_str = iter_ori->second;
|
||||
if (des == nullptr) return false;
|
||||
for (auto iter = src.begin(); iter != src.end(); ++iter) {
|
||||
std::string des_id = iter->first;
|
||||
auto &cur_clientshare_str = iter->second;
|
||||
for (size_t index_clientshare = 0; index_clientshare < cur_clientshare_str.size(); ++index_clientshare) {
|
||||
std::string fl_src = cur_clientshare_str[index_clientshare].fl_id;
|
||||
std::string src_id = cur_clientshare_str[index_clientshare].fl_id;
|
||||
clientshare_str value;
|
||||
value.fl_id = fl_des;
|
||||
value.fl_id = des_id;
|
||||
value.share = cur_clientshare_str[index_clientshare].share;
|
||||
value.index = cur_clientshare_str[index_clientshare].index;
|
||||
if (des->find(fl_src) == des->end()) { // fl_id_des is not in reconstruct_secret_list_
|
||||
if (des->find(src_id) == des->end()) { // src_id is not in recombined shares list
|
||||
std::vector<clientshare_str> value_list;
|
||||
value_list.push_back(value);
|
||||
des->insert(std::pair<std::string, std::vector<clientshare_str>>(fl_src, value_list));
|
||||
} else { // fl_id_des is in reconstruct_secret_list_
|
||||
des->at(fl_src).push_back(value);
|
||||
des->insert(std::pair<std::string, std::vector<clientshare_str>>(src_id, value_list));
|
||||
} else {
|
||||
des->at(src_id).push_back(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int shares_size) {
|
||||
if (shares_tmp == nullptr) return false;
|
||||
for (int i = 0; i < shares_size; ++i) {
|
||||
Share *share_i = new Share;
|
||||
Share *share_i = new Share();
|
||||
if (share_i == nullptr) {
|
||||
MS_LOG(ERROR) << "shares_tmp " << i << " memory to cipher is invalid.";
|
||||
DeleteShares(shares_tmp);
|
||||
return false;
|
||||
}
|
||||
share_i->data = new unsigned char[SHARE_MAX_SIZE];
|
||||
share_i->data = new uint8_t[SHARE_MAX_SIZE];
|
||||
if (share_i->data == nullptr) {
|
||||
MS_LOG(ERROR) << "shares_tmp's data " << i << " memory to cipher is invalid.";
|
||||
DeleteShares(shares_tmp);
|
||||
|
@ -405,6 +458,7 @@ bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int share
|
|||
}
|
||||
|
||||
void CipherReconStruct::DeleteShares(std::vector<Share *> *shares_tmp) {
|
||||
if (shares_tmp == nullptr) return;
|
||||
if (shares_tmp->size() != 0) {
|
||||
for (size_t i = 0; i < shares_tmp->size(); ++i) {
|
||||
if (shares_tmp->at(i) != nullptr && shares_tmp->at(i)->data != nullptr) {
|
||||
|
|
|
@ -28,6 +28,8 @@
|
|||
#include "fl/armour/cipher/cipher_init.h"
|
||||
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||
|
||||
#define IV_NUM 3
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
// The process of reconstruct secret mask in the secure aggregation
|
||||
|
@ -44,7 +46,7 @@ class CipherReconStruct {
|
|||
// reconstruct secret mask
|
||||
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
|
||||
const schema::SendReconstructSecret *reconstruct_secret_req,
|
||||
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
|
||||
const std::shared_ptr<fl::server::FBBuilder> &fbb,
|
||||
const std::vector<std::string> &client_list);
|
||||
|
||||
// build response code of reconstruct secret.
|
||||
|
@ -60,26 +62,28 @@ class CipherReconStruct {
|
|||
bool GetSymbol(const std::string &str1, const std::string &str2);
|
||||
// get suv noise by computing shares result.
|
||||
bool GetSuvNoise(const std::vector<std::string> &clients_share_list,
|
||||
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
|
||||
const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length);
|
||||
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
|
||||
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs, const string &fl_id,
|
||||
std::vector<float> *noise, uint8_t *secret, int length);
|
||||
// malloc shares.
|
||||
bool MallocShares(std::vector<Share *> *shares_tmp, int shares_size);
|
||||
// delete shares.
|
||||
void DeleteShares(std::vector<Share *> *shares_tmp);
|
||||
// convert shares from receiving clients to sending clients.
|
||||
void ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
|
||||
bool ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
|
||||
std::map<std::string, std::vector<clientshare_str>> *des);
|
||||
// generate noise from shares.
|
||||
bool ReconstructSecretsGenNoise(const std::vector<string> &client_list);
|
||||
// get noise masks sum.
|
||||
bool GetNoiseMasksSum(std::vector<float> *result, const std::map<std::string, std::vector<float>> &client_keys);
|
||||
bool GetNoiseMasksSum(std::vector<float> *result, const std::map<std::string, std::vector<float>> &client_noise);
|
||||
|
||||
// combine noise mask.
|
||||
bool CombineMask(std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_keys,
|
||||
bool CombineMask(std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_noise,
|
||||
const std::vector<std::string> &clients_share_list,
|
||||
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
|
||||
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
|
||||
const std::vector<string> &client_list);
|
||||
const std::vector<string> &client_list,
|
||||
const std::map<std::string, std::vector<std::vector<unsigned char>>> &client_ivs);
|
||||
};
|
||||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,8 +25,8 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
|
|||
const string next_req_time) {
|
||||
MS_LOG(INFO) << "CipherShares::ShareSecrets START";
|
||||
if (share_secrets_req == nullptr) {
|
||||
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
|
||||
std::string reason = "Request is nullptr or Response builder is nullptr.";
|
||||
std::string reason = "Request is nullptr";
|
||||
MS_LOG(ERROR) << reason;
|
||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time,
|
||||
cur_iterator);
|
||||
return false;
|
||||
|
@ -34,38 +34,34 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
|
|||
|
||||
// step 1: get client list and share secrets from memory server.
|
||||
clock_t start_time = clock();
|
||||
|
||||
int iteration = share_secrets_req->iteration();
|
||||
std::vector<std::string> get_keys_clients;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetKeysClientList, &get_keys_clients);
|
||||
std::vector<std::string> clients_share_list;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
||||
&clients_share_list);
|
||||
std::vector<std::string> clients_exchange_list;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList,
|
||||
&clients_exchange_list);
|
||||
|
||||
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
|
||||
&encrypted_shares_all);
|
||||
|
||||
MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size()
|
||||
<< "client of shares size : " << clients_share_list.size() << "shares size"
|
||||
<< encrypted_shares_all.size();
|
||||
if (encrypted_shares_all.size() != clients_share_list.size()) {
|
||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime,
|
||||
"client of shares and shares size are not equal", next_req_time, cur_iterator);
|
||||
MS_LOG(ERROR) << "client of shares and shares size are not equal. client of shares size : "
|
||||
<< clients_share_list.size() << "shares size" << encrypted_shares_all.size();
|
||||
}
|
||||
MS_LOG(INFO) << "Client of get keys size : " << get_keys_clients.size()
|
||||
<< "client of update shares size : " << clients_share_list.size()
|
||||
<< "updated shares size: " << encrypted_shares_all.size();
|
||||
|
||||
// step 2: update new item to memory server. serialise: update pb struct to memory server.
|
||||
int iteration = share_secrets_req->iteration();
|
||||
|
||||
std::string fl_id_src = share_secrets_req->fl_id()->str();
|
||||
if (find(clients_exchange_list.begin(), clients_exchange_list.end(), fl_id_src) ==
|
||||
clients_exchange_list.end()) { // the client not in clients_exchange_list, return false.
|
||||
if (find(get_keys_clients.begin(), get_keys_clients.end(), fl_id_src) == get_keys_clients.end()) {
|
||||
// the client not in get keys clients
|
||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError,
|
||||
("client share secret is not in clients_exchange list. && client is illegal"), next_req_time,
|
||||
("client share secret is not in getkeys list. && client is illegal"), next_req_time,
|
||||
iteration);
|
||||
return false;
|
||||
}
|
||||
if (find(clients_share_list.begin(), clients_share_list.end(), fl_id_src) !=
|
||||
clients_share_list.end()) { // the client is already exists, return false.
|
||||
|
||||
if (encrypted_shares_all.find(fl_id_src) != encrypted_shares_all.end()) { // the client is already exists
|
||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED,
|
||||
("client sharesecret already exists."), next_req_time, iteration);
|
||||
return false;
|
||||
|
@ -74,24 +70,22 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
|
|||
// update new item to memory server.
|
||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
|
||||
(share_secrets_req->encrypted_shares());
|
||||
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
||||
fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
|
||||
bool retcode_client =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src);
|
||||
bool retcode = retcode_share && retcode_client;
|
||||
if (retcode) {
|
||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
|
||||
MS_LOG(INFO) << "CipherShares::ShareSecrets Success";
|
||||
} else {
|
||||
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
||||
fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
|
||||
if (!(retcode_share && retcode_client)) {
|
||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime,
|
||||
"update client of shares and shares failed", next_req_time, iteration);
|
||||
MS_LOG(ERROR) << "CipherShares::ShareSecrets update client of shares and shares failed ";
|
||||
return false;
|
||||
}
|
||||
|
||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
|
||||
MS_LOG(INFO) << "CipherShares::ShareSecrets Success";
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "ShareSecrets get + deal + update data time is : " << duration;
|
||||
return retcode;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
||||
|
@ -101,36 +95,38 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
|||
clock_t start_time = clock();
|
||||
// step 0: check whether the parameters are legal.
|
||||
if (get_secrets_req == nullptr) {
|
||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, 0);
|
||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, nullptr);
|
||||
MS_LOG(ERROR) << "GetSecrets: get_secrets_req is nullptr.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// step 1: get client list and client shares list from memory server.
|
||||
std::vector<std::string> clients_share_list;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
||||
&clients_share_list);
|
||||
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
|
||||
&encrypted_shares_all);
|
||||
int iteration = get_secrets_req->iteration();
|
||||
size_t share_clients_num = clients_share_list.size();
|
||||
size_t cients_has_shares = encrypted_shares_all.size();
|
||||
if (share_clients_num != cients_has_shares) {
|
||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_OutOfTime, iteration, next_req_time, 0);
|
||||
MS_LOG(ERROR) << "cients_has_shares: " << cients_has_shares << "share_clients_num: " << share_clients_num;
|
||||
}
|
||||
if (cipher_init_->share_clients_num_need_ > share_clients_num) { // the client num is not enough, return false.
|
||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, 0);
|
||||
MS_LOG(INFO) << "GetSecrets: the client num is not enough: share_clients_num_need_: "
|
||||
<< cipher_init_->share_clients_num_need_ << "share_clients_num: " << share_clients_num;
|
||||
size_t encrypted_shares_num = encrypted_shares_all.size();
|
||||
if (cipher_init_->share_secrets_threshold > encrypted_shares_num) { // the client num is not enough, return false.
|
||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
|
||||
MS_LOG(INFO) << "GetSecrets: the encrypted shares num is not enough: share_secrets_threshold: "
|
||||
<< cipher_init_->share_secrets_threshold << "encrypted_shares_num: " << encrypted_shares_num;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string fl_id = get_secrets_req->fl_id()->str();
|
||||
if (find(clients_share_list.begin(), clients_share_list.end(), fl_id) ==
|
||||
clients_share_list.end()) { // the client is not in client list, return false.
|
||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, 0);
|
||||
MS_LOG(ERROR) << "GetSecrets: client is not in client list.";
|
||||
// the client is not in share secrets client list.
|
||||
if (encrypted_shares_all.find(fl_id) == encrypted_shares_all.end()) {
|
||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, nullptr);
|
||||
MS_LOG(ERROR) << "GetSecrets: client is not in share secrets client list.";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool retcode_client =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetSecretsClientList, fl_id);
|
||||
if (!retcode_client) {
|
||||
MS_LOG(ERROR) << "update get secrets clients failed";
|
||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
|
||||
return false;
|
||||
}
|
||||
|
||||
// get the result client shares.
|
||||
|
@ -171,7 +167,6 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
|||
|
||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SUCCEED, iteration, next_req_time,
|
||||
&encrypted_shares);
|
||||
|
||||
MS_LOG(INFO) << "CipherShares::GetSecrets Success";
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
|
@ -185,7 +180,7 @@ void CipherShares::BuildGetSecretsRsp(
|
|||
int rsp_retcode = retcode;
|
||||
int rsp_iteration = iteration;
|
||||
auto rsp_next_req_time = get_secrets_resp_builder->CreateString(next_req_time);
|
||||
if (encrypted_shares == 0) {
|
||||
if (encrypted_shares == nullptr) {
|
||||
auto get_secrets_rsp =
|
||||
schema::CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration, 0, rsp_next_req_time);
|
||||
get_secrets_resp_builder->Finish(get_secrets_rsp);
|
||||
|
@ -195,7 +190,6 @@ void CipherShares::BuildGetSecretsRsp(
|
|||
encrypted_shares_rsp, rsp_next_req_time);
|
||||
get_secrets_resp_builder->Finish(get_secrets_rsp);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -26,8 +26,8 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
|
|||
clock_t start_time = clock();
|
||||
std::vector<float> noise;
|
||||
|
||||
(void)cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
|
||||
if (noise.size() != cipher_init_->featuremap_) {
|
||||
bool ret = cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
|
||||
if (!ret || noise.size() != cipher_init_->featuremap_) {
|
||||
MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -18,12 +18,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
||||
#define KEY_STEP_MAX 32
|
||||
#define KEY_STEP_MIN 16
|
||||
#define PAD_SIZE 5
|
||||
|
||||
AESEncrypt::AESEncrypt(const unsigned char *key, int key_len, unsigned char *ivec, int ivec_len, const AES_MODE mode) {
|
||||
AESEncrypt::AESEncrypt(const uint8_t *key, int key_len, const uint8_t *ivec, int ivec_len, const AES_MODE mode) {
|
||||
privKey = key;
|
||||
privKeyLen = key_len;
|
||||
iVec = ivec;
|
||||
|
@ -47,12 +42,20 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
|
|||
#else
|
||||
int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) {
|
||||
int ret;
|
||||
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) {
|
||||
MS_LOG(ERROR) << "key length must be 16 or 32!";
|
||||
if (privKey == NULL || iVec == NULL) {
|
||||
MS_LOG(ERROR) << "private key or init vector is invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (iVecLen != INIT_VEC_SIZE) {
|
||||
MS_LOG(ERROR) << "initial vector size must be 16!";
|
||||
if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
|
||||
MS_LOG(ERROR) << "key length is invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (iVecLen != AES_IV_SIZE) {
|
||||
MS_LOG(ERROR) << "initial vector size is invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (data == NULL || len <= 0 || encrypt_data == NULL) {
|
||||
MS_LOG(ERROR) << "input data is invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (aesMode == AES_CBC || aesMode == AES_CTR) {
|
||||
|
@ -69,12 +72,20 @@ int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned c
|
|||
|
||||
int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) {
|
||||
int ret = 0;
|
||||
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) {
|
||||
MS_LOG(ERROR) << "key length must be 16 or 32!";
|
||||
if (privKey == NULL || iVec == NULL) {
|
||||
MS_LOG(ERROR) << "private key or init vector is invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (iVecLen != INIT_VEC_SIZE) {
|
||||
MS_LOG(ERROR) << "initial vector size must be 16!";
|
||||
if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
|
||||
MS_LOG(ERROR) << "key length is invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (iVecLen != AES_IV_SIZE) {
|
||||
MS_LOG(ERROR) << "initial vector size is invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (data == NULL || encrypt_len <= 0 || encrypt_data == NULL) {
|
||||
MS_LOG(ERROR) << "input data is invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (aesMode == AES_CBC || aesMode == AES_CTR) {
|
||||
|
@ -88,17 +99,21 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
|
|||
return 0;
|
||||
}
|
||||
|
||||
int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const unsigned char *key, unsigned char *ivec,
|
||||
unsigned char *encrypt_data, int *encrypt_len) {
|
||||
int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
|
||||
uint8_t *encrypt_data, int *encrypt_len) {
|
||||
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
||||
if (ctx == NULL) {
|
||||
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
|
||||
return -1;
|
||||
}
|
||||
int out_len;
|
||||
int ret = 0;
|
||||
int ret;
|
||||
if (aesMode == AES_CBC) {
|
||||
switch (privKeyLen) {
|
||||
case 16:
|
||||
case KEY_LENGTH_16:
|
||||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
|
||||
break;
|
||||
case 32:
|
||||
case KEY_LENGTH_32:
|
||||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
|
||||
break;
|
||||
default:
|
||||
|
@ -107,16 +122,16 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
|
|||
}
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex CBC fail!";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
EVP_CIPHER_CTX_set_key_length(ctx, EVP_MAX_KEY_LENGTH);
|
||||
EVP_CIPHER_CTX_set_padding(ctx, PAD_SIZE);
|
||||
EVP_CIPHER_CTX_set_padding(ctx, EVP_PADDING_PKCS7);
|
||||
} else if (aesMode == AES_CTR) {
|
||||
switch (privKeyLen) {
|
||||
case 16:
|
||||
case KEY_LENGTH_16:
|
||||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
|
||||
break;
|
||||
case 32:
|
||||
case KEY_LENGTH_32:
|
||||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
|
||||
break;
|
||||
default:
|
||||
|
@ -125,21 +140,25 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
|
|||
}
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex CTR fail!";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported AES mode";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
ret = EVP_EncryptUpdate(ctx, encrypt_data, &out_len, data, len);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptUpdate fail!";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
*encrypt_len = out_len;
|
||||
ret = EVP_EncryptFinal_ex(ctx, encrypt_data + *encrypt_len, &out_len);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_EncryptFinal_ex fail!";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
*encrypt_len += out_len;
|
||||
|
@ -147,17 +166,21 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
|
|||
return 0;
|
||||
}
|
||||
|
||||
int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len, const unsigned char *key,
|
||||
unsigned char *ivec, unsigned char *decrypt_data, int *decrypt_len) {
|
||||
int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
|
||||
uint8_t *decrypt_data, int *decrypt_len) {
|
||||
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
||||
if (ctx == NULL) {
|
||||
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
|
||||
return -1;
|
||||
}
|
||||
int out_len;
|
||||
int ret = 0;
|
||||
int ret;
|
||||
if (aesMode == AES_CBC) {
|
||||
switch (privKeyLen) {
|
||||
case 16:
|
||||
case KEY_LENGTH_16:
|
||||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
|
||||
break;
|
||||
case 32:
|
||||
case KEY_LENGTH_32:
|
||||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
|
||||
break;
|
||||
default:
|
||||
|
@ -165,40 +188,46 @@ int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len
|
|||
ret = -1;
|
||||
}
|
||||
if (ret != 1) {
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
EVP_CIPHER_CTX_set_key_length(ctx, EVP_MAX_KEY_LENGTH);
|
||||
} else if (aesMode == AES_CTR) {
|
||||
switch (privKeyLen) {
|
||||
case 16:
|
||||
case KEY_LENGTH_16:
|
||||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
|
||||
break;
|
||||
case 32:
|
||||
case KEY_LENGTH_32:
|
||||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(ERROR) << "key length is incorrect!";
|
||||
ret = -1;
|
||||
}
|
||||
if (ret != 1) {
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
ret = -1;
|
||||
}
|
||||
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "Unsupported AES mode";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
|
||||
ret = EVP_DecryptUpdate(ctx, decrypt_data, &out_len, encrypt_data, len);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptUpdate fail!";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
*decrypt_len = out_len;
|
||||
ret = EVP_DecryptFinal_ex(ctx, decrypt_data + *decrypt_len, &out_len);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "EVP_DecryptFinal_ex fail!";
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
*decrypt_len += out_len;
|
||||
EVP_CIPHER_CTX_free(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -22,7 +22,9 @@
|
|||
#endif
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
#define INIT_VEC_SIZE 16
|
||||
#define AES_IV_SIZE 16
|
||||
#define KEY_LENGTH_32 32
|
||||
#define KEY_LENGTH_16 16
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
@ -35,21 +37,21 @@ class SymmetricEncrypt : Encrypt {};
|
|||
|
||||
class AESEncrypt : SymmetricEncrypt {
|
||||
public:
|
||||
AESEncrypt(const unsigned char *key, int key_len, unsigned char *ivec, int ivec_len, AES_MODE mode);
|
||||
AESEncrypt(const unsigned char *key, int key_len, const uint8_t *ivec, int ivec_len, AES_MODE mode);
|
||||
~AESEncrypt();
|
||||
int EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len);
|
||||
int DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len);
|
||||
int EncryptData(const uint8_t *data, const int len, uint8_t *encrypt_data, int *encrypt_len);
|
||||
int DecryptData(const uint8_t *encrypt_data, const int encrypt_len, uint8_t *data, int *len);
|
||||
|
||||
private:
|
||||
const unsigned char *privKey;
|
||||
const uint8_t *privKey;
|
||||
int privKeyLen;
|
||||
unsigned char *iVec;
|
||||
const uint8_t *iVec;
|
||||
int iVecLen;
|
||||
AES_MODE aesMode;
|
||||
int evp_aes_encrypt(const unsigned char *data, const int len, const unsigned char *key, unsigned char *ivec,
|
||||
unsigned char *encrypt_data, int *encrypt_len);
|
||||
int evp_aes_decrypt(const unsigned char *encrypt_data, const int len, const unsigned char *key, unsigned char *ivec,
|
||||
unsigned char *decrypt_data, int *decrypt_len);
|
||||
int evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
|
||||
uint8_t *encrypt_data, int *encrypt_len);
|
||||
int evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
|
||||
uint8_t *decrypt_data, int *decrypt_len);
|
||||
};
|
||||
|
||||
} // namespace armour
|
||||
|
|
|
@ -54,14 +54,22 @@ PrivateKey::PrivateKey(EVP_PKEY *evpKey) { evpPrivKey = evpKey; }
|
|||
|
||||
PrivateKey::~PrivateKey() { EVP_PKEY_free(evpPrivKey); }
|
||||
|
||||
int PrivateKey::GetPrivateBytes(size_t *len, unsigned char *privKeyBytes) {
|
||||
int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) {
|
||||
if (privKeyBytes == nullptr || len <= 0) {
|
||||
MS_LOG(ERROR) << "input privKeyBytes invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (!EVP_PKEY_get_raw_private_key(evpPrivKey, privKeyBytes, len)) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int PrivateKey::GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) {
|
||||
int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) {
|
||||
if (pubKeyBytes == nullptr || len <= 0) {
|
||||
MS_LOG(ERROR) << "input pubKeyBytes invalid.";
|
||||
return -1;
|
||||
}
|
||||
if (!EVP_PKEY_get_raw_public_key(evpPrivKey, pubKeyBytes, len)) {
|
||||
return -1;
|
||||
}
|
||||
|
@ -69,7 +77,19 @@ int PrivateKey::GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) {
|
|||
}
|
||||
|
||||
int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned char *salt, int salt_len,
|
||||
unsigned char *exchangeKey) {
|
||||
uint8_t *exchangeKey) {
|
||||
if (peerPublicKey == nullptr) {
|
||||
MS_LOG(ERROR) << "peerPublicKey is nullptr.";
|
||||
return -1;
|
||||
}
|
||||
if (key_len != KEY_LEN || exchangeKey == nullptr) {
|
||||
MS_LOG(ERROR) << "exchangeKey is nullptr or input key_len is incorrect.";
|
||||
return -1;
|
||||
}
|
||||
if (salt == nullptr || salt_len != SALT_LEN) {
|
||||
MS_LOG(ERROR) << "input salt in invalid.";
|
||||
return -1;
|
||||
}
|
||||
EVP_PKEY_CTX *ctx;
|
||||
size_t len = 0;
|
||||
ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL);
|
||||
|
@ -79,30 +99,38 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c
|
|||
}
|
||||
if (EVP_PKEY_derive_init(ctx) <= 0) {
|
||||
MS_LOG(ERROR) << "EVP_PKEY_derive_init failed!";
|
||||
EVP_PKEY_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
if (EVP_PKEY_derive_set_peer(ctx, peerPublicKey->evpPubKey) <= 0) {
|
||||
MS_LOG(ERROR) << "EVP_PKEY_derive_set_peer failed!";
|
||||
EVP_PKEY_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
unsigned char *secret;
|
||||
if (EVP_PKEY_derive(ctx, NULL, &len) <= 0) {
|
||||
MS_LOG(ERROR) << "get derive key size failed!";
|
||||
EVP_PKEY_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
|
||||
secret = (unsigned char *)OPENSSL_malloc(len);
|
||||
if (!secret) {
|
||||
MS_LOG(ERROR) << "malloc secret memory failed!";
|
||||
EVP_PKEY_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (EVP_PKEY_derive(ctx, secret, &len) <= 0) {
|
||||
MS_LOG(ERROR) << "derive key failed!";
|
||||
OPENSSL_free(secret);
|
||||
EVP_PKEY_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!PKCS5_PBKDF2_HMAC((char *)secret, len, salt, salt_len, ITERATION, EVP_sha256(), key_len, exchangeKey)) {
|
||||
if (!PKCS5_PBKDF2_HMAC(reinterpret_cast<char *>(secret), len, salt, salt_len, ITERATION, EVP_sha256(), key_len,
|
||||
exchangeKey)) {
|
||||
OPENSSL_free(secret);
|
||||
EVP_PKEY_CTX_free(ctx);
|
||||
return -1;
|
||||
}
|
||||
OPENSSL_free(secret);
|
||||
|
@ -118,9 +146,11 @@ PrivateKey *KeyAgreement::GeneratePrivKey() {
|
|||
return NULL;
|
||||
}
|
||||
if (EVP_PKEY_keygen_init(pctx) <= 0) {
|
||||
EVP_PKEY_CTX_free(pctx);
|
||||
return NULL;
|
||||
}
|
||||
if (EVP_PKEY_keygen(pctx, &evpKey) <= 0) {
|
||||
EVP_PKEY_CTX_free(pctx);
|
||||
return NULL;
|
||||
}
|
||||
EVP_PKEY_CTX_free(pctx);
|
||||
|
@ -131,14 +161,30 @@ PrivateKey *KeyAgreement::GeneratePrivKey() {
|
|||
PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
|
||||
unsigned char *pubKeyBytes;
|
||||
size_t len = 0;
|
||||
if (privKey == nullptr) {
|
||||
return NULL;
|
||||
}
|
||||
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, NULL, &len)) {
|
||||
return NULL;
|
||||
}
|
||||
pubKeyBytes = (unsigned char *)OPENSSL_malloc(len);
|
||||
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, pubKeyBytes, &len)) {
|
||||
pubKeyBytes = reinterpret_cast<uint8_t *>(OPENSSL_malloc(len));
|
||||
if (!pubKeyBytes) {
|
||||
MS_LOG(ERROR) << "malloc secret memory failed!";
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, pubKeyBytes, &len)) {
|
||||
MS_LOG(ERROR) << "EVP_PKEY_get_raw_public_key failed!";
|
||||
OPENSSL_free(pubKeyBytes);
|
||||
return NULL;
|
||||
}
|
||||
EVP_PKEY *evp_pubKey =
|
||||
EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, reinterpret_cast<uint8_t *>(pubKeyBytes), len);
|
||||
if (evp_pubKey == NULL) {
|
||||
MS_LOG(ERROR) << "EVP_PKEY_new_raw_public_key failed!";
|
||||
OPENSSL_free(pubKeyBytes);
|
||||
return NULL;
|
||||
}
|
||||
EVP_PKEY *evp_pubKey = EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, (unsigned char *)pubKeyBytes, len);
|
||||
OPENSSL_free(pubKeyBytes);
|
||||
PublicKey *pubKey = new PublicKey(evp_pubKey);
|
||||
return pubKey;
|
||||
|
@ -147,6 +193,7 @@ PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
|
|||
PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) {
|
||||
EVP_PKEY *evp_Key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, NULL, data, len);
|
||||
if (evp_Key == NULL) {
|
||||
MS_LOG(ERROR) << "create evp_Key from raw bytes failed!";
|
||||
return NULL;
|
||||
}
|
||||
PrivateKey *privKey = new PrivateKey(evp_Key);
|
||||
|
@ -164,7 +211,11 @@ PublicKey *KeyAgreement::FromPublicBytes(unsigned char *data, int len) {
|
|||
}
|
||||
|
||||
int KeyAgreement::ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey, int key_len,
|
||||
const unsigned char *salt, int salt_len, unsigned char *exchangeKey) {
|
||||
const unsigned char *salt, int salt_len, uint8_t *exchangeKey) {
|
||||
if (privKey == nullptr) {
|
||||
MS_LOG(ERROR) << "privKey is nullptr!";
|
||||
return -1;
|
||||
}
|
||||
return privKey->Exchange(peerPublicKey, key_len, salt, salt_len, exchangeKey);
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -24,7 +24,8 @@
|
|||
#endif
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
#define KEK_KEY_LEN 32
|
||||
#define KEY_LEN 32
|
||||
#define SALT_LEN 32
|
||||
#define ITERATION 10000
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -14,44 +14,39 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "fl/armour/secure_protocol/random.h"
|
||||
#include "fl/armour/secure_protocol/masking.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
Random::Random(size_t init_seed) { generator.seed(init_seed); }
|
||||
|
||||
Random::~Random() {}
|
||||
|
||||
#ifdef _WIN32
|
||||
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
|
||||
int Masking::GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *seed, int seed_len,
|
||||
const uint8_t *ivec, int ivec_size) {
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
return -1;
|
||||
}
|
||||
|
||||
#else
|
||||
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
|
||||
int retval = RAND_priv_bytes(secret, num_bytes);
|
||||
return retval;
|
||||
}
|
||||
|
||||
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
|
||||
if (seed_len != 16 && seed_len != 32) {
|
||||
MS_LOG(ERROR) << "seed length must be 16 or 32!";
|
||||
int Masking::GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *secret, int secret_len,
|
||||
const uint8_t *ivec, int ivec_size) {
|
||||
if ((secret_len != KEY_LENGTH_16 && secret_len != KEY_LENGTH_32) || secret == NULL) {
|
||||
MS_LOG(ERROR) << "secret is invalid!";
|
||||
return -1;
|
||||
}
|
||||
if (noise == NULL || noise_len <= 0) {
|
||||
MS_LOG(ERROR) << "noise is invalid!";
|
||||
return -1;
|
||||
}
|
||||
if (ivec == NULL || ivec_size != AES_IV_SIZE) {
|
||||
MS_LOG(ERROR) << "ivec is invalid!";
|
||||
return -1;
|
||||
}
|
||||
int size = noise_len * sizeof(int);
|
||||
std::vector<unsigned char> data(size, 0);
|
||||
std::vector<unsigned char> encrypt_data(size, 0);
|
||||
std::vector<unsigned char> ivec(INIT_VEC_SIZE, 0);
|
||||
std::vector<uint8_t> data(size, 0);
|
||||
std::vector<uint8_t> encrypt_data(size, 0);
|
||||
int encrypt_len = 0;
|
||||
AESEncrypt encrypt(seed, seed_len, ivec.data(), INIT_VEC_SIZE, AES_CTR);
|
||||
AESEncrypt encrypt(secret, secret_len, ivec, AES_IV_SIZE, AES_CTR);
|
||||
if (encrypt.EncryptData(data.data(), size, encrypt_data.data(), &encrypt_len) != 0) {
|
||||
MS_LOG(ERROR) << "call encryptData fail!";
|
||||
MS_LOG(ERROR) << "call AES-CTR failed!";
|
||||
return -1;
|
||||
}
|
||||
|
|
@ -19,27 +19,15 @@
|
|||
|
||||
#include <random>
|
||||
#include <vector>
|
||||
#ifndef _WIN32
|
||||
#include <openssl/rand.h>
|
||||
#endif
|
||||
#include "fl/armour/secure_protocol/encrypt.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace armour {
|
||||
|
||||
#define RANDOM_LEN 8
|
||||
|
||||
class Random {
|
||||
class Masking {
|
||||
public:
|
||||
explicit Random(size_t init_seed);
|
||||
~Random();
|
||||
// use openssl RAND_priv_bytes
|
||||
static int GetRandomBytes(unsigned char *secret, int num_bytes);
|
||||
|
||||
static int RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len);
|
||||
|
||||
private:
|
||||
std::default_random_engine generator;
|
||||
static int GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *secret, int secret_len,
|
||||
const uint8_t *ivec, int ivec_size);
|
||||
};
|
||||
} // namespace armour
|
||||
} // namespace mindspore
|
|
@ -62,6 +62,11 @@ struct RoundConfig {
|
|||
struct CipherConfig {
|
||||
float share_secrets_ratio = 1.0;
|
||||
uint64_t cipher_time_window = 300000;
|
||||
size_t exchange_keys_threshold = 0;
|
||||
size_t get_keys_threshold = 0;
|
||||
size_t share_secrets_threshold = 0;
|
||||
size_t get_secrets_threshold = 0;
|
||||
size_t client_list_threshold = 0;
|
||||
size_t reconstruct_secrets_threshold = 0;
|
||||
};
|
||||
|
||||
|
@ -207,8 +212,11 @@ constexpr auto kCtxClientNoises = "clients_noises";
|
|||
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";
|
||||
constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares";
|
||||
constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list";
|
||||
constexpr auto kCtxGetSecretsClientList = "get_secrets_client_list";
|
||||
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
|
||||
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
|
||||
constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list";
|
||||
constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
|
||||
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
|
||||
constexpr auto kCtxCipherPrimer = "cipher_primer";
|
||||
|
||||
|
|
|
@ -41,108 +41,102 @@ void ClientListKernel::InitKernel(size_t) {
|
|||
|
||||
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
|
||||
std::shared_ptr<server::FBBuilder> fbb) {
|
||||
bool response = false;
|
||||
std::vector<string> client_list;
|
||||
std::vector<string> empty_client_list;
|
||||
std::string fl_id = get_clients_req->fl_id()->str();
|
||||
int32_t iter_client = (size_t)get_clients_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "ClientListKernel iteration invalid. servertime is " << iter_num;
|
||||
MS_LOG(ERROR) << "ClientListKernel iteration invalid. clienttime is " << iter_client;
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
|
||||
uint64_t update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
|
||||
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
|
||||
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
||||
client_list.push_back(client_list_pb.fl_id(i));
|
||||
}
|
||||
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) { // client in client_list.
|
||||
if (static_cast<uint64_t>(client_list_pb.fl_id_size()) >= update_model_client_num) {
|
||||
MS_LOG(INFO) << "send clients_list succeed!";
|
||||
MS_LOG(INFO) << "UpdateModel client list: ";
|
||||
for (size_t i = 0; i < client_list.size(); ++i) {
|
||||
MS_LOG(INFO) << " fl_id : " << client_list[i];
|
||||
}
|
||||
MS_LOG(INFO) << "update_model_client_num: " << update_model_client_num;
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
response = true;
|
||||
} else {
|
||||
MS_LOG(INFO) << "The server is not ready. update_model_client_need_num: " << update_model_client_num;
|
||||
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
|
||||
/*for (size_t i = 0; i < std::min(client_list.size(), size_t(2)); ++i) {
|
||||
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
|
||||
}
|
||||
for (size_t i = client_list.size() - size_t(1); i > std::max(client_list.size() - size_t(2), size_t(0));
|
||||
--i) {
|
||||
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
|
||||
}*/
|
||||
int count_tmp = 0;
|
||||
for (size_t i = 0; i < cipher_init_->get_model_num_need_; ++i) {
|
||||
size_t j = 0;
|
||||
for (; j < client_list.size(); ++j) {
|
||||
if (("f" + std::to_string(i)) == client_list[j]) break;
|
||||
}
|
||||
if (j >= client_list.size()) {
|
||||
count_tmp++;
|
||||
MS_LOG(INFO) << " no client_list fl_id : " << i;
|
||||
if (count_tmp > 3) break;
|
||||
}
|
||||
}
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
}
|
||||
}
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str());
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "update_model_client_num is zero.";
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_num is zero.", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
}
|
||||
|
||||
if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
|
||||
MS_LOG(ERROR) << "update_model_client_threshold is not set.";
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_threshold is not set.",
|
||||
empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
return false;
|
||||
}
|
||||
return response;
|
||||
uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
|
||||
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
|
||||
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
||||
client_list.push_back(client_list_pb.fl_id(i));
|
||||
}
|
||||
if (static_cast<uint64_t>(client_list.size()) < update_model_client_needed) {
|
||||
MS_LOG(INFO) << "The server is not ready. update_model_client_needed: " << update_model_client_needed;
|
||||
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", empty_client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (find(client_list.begin(), client_list.end(), fl_id) == client_list.end()) { // client not in update model clients
|
||||
std::string reason = "fl_id: " + fl_id + " is not in the update_model_clients";
|
||||
MS_LOG(INFO) << reason;
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, empty_client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool retcode_client =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetUpdateModelClientList, fl_id);
|
||||
if (!retcode_client) {
|
||||
std::string reason = "update get update model clients failed";
|
||||
MS_LOG(ERROR) << reason;
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, reason, empty_client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
return false;
|
||||
}
|
||||
if (!DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str())) {
|
||||
std::string reason = "Counting for get user list request failed. Please retry later.";
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return true;
|
||||
}
|
||||
MS_LOG(INFO) << "send clients_list succeed!";
|
||||
MS_LOG(INFO) << "UpdateModel client list: ";
|
||||
for (size_t i = 0; i < client_list.size(); ++i) {
|
||||
MS_LOG(INFO) << " fl_id : " << client_list[i];
|
||||
}
|
||||
MS_LOG(INFO) << "update_model_client_needed: " << update_model_client_needed;
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
std::vector<string> client_list;
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ClientListKernel needs 1 input,but got " << inputs.size();
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel input num not match", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ClientListKernel needs 1 output,but got " << outputs.size();
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel output num not match", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "ClientListKernel num is enough", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
|
||||
|
||||
if (get_clients_req == nullptr || fbb == nullptr) {
|
||||
MS_LOG(ERROR) << "GetClientList is nullptr or ClientListRsp builder is nullptr.";
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_RequestError,
|
||||
"GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
DealClient(iter_num, get_clients_req, fbb);
|
||||
}
|
||||
}
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
std::vector<string> client_list;
|
||||
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
|
||||
int32_t iter_client = (size_t)get_clients_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "client list iteration number is invalid: server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
|
||||
}
|
||||
|
||||
DealClient(iter_num, get_clients_req, fbb);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
|
@ -164,30 +158,22 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli
|
|||
const int iteration) {
|
||||
auto rsp_reason = client_list_resp_builder->CreateString(reason);
|
||||
auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time);
|
||||
if (clients.size() > 0) {
|
||||
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
|
||||
for (auto client : clients) {
|
||||
auto client_fb = client_list_resp_builder->CreateString(client);
|
||||
clients_vector.push_back(client_fb);
|
||||
}
|
||||
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
|
||||
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
|
||||
rsp_builder.add_retcode(retcode);
|
||||
rsp_builder.add_reason(rsp_reason);
|
||||
rsp_builder.add_clients(clients_fb);
|
||||
rsp_builder.add_iteration(iteration);
|
||||
rsp_builder.add_next_req_time(rsp_next_req_time);
|
||||
auto rsp_exchange_keys = rsp_builder.Finish();
|
||||
client_list_resp_builder->Finish(rsp_exchange_keys);
|
||||
} else {
|
||||
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
|
||||
rsp_builder.add_retcode(retcode);
|
||||
rsp_builder.add_reason(rsp_reason);
|
||||
rsp_builder.add_iteration(iteration);
|
||||
rsp_builder.add_next_req_time(rsp_next_req_time);
|
||||
auto rsp_exchange_keys = rsp_builder.Finish();
|
||||
client_list_resp_builder->Finish(rsp_exchange_keys);
|
||||
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
|
||||
for (auto client : clients) {
|
||||
auto client_fb = client_list_resp_builder->CreateString(client);
|
||||
clients_vector.push_back(client_fb);
|
||||
MS_LOG(WARNING) << "update client list: ";
|
||||
MS_LOG(WARNING) << client;
|
||||
}
|
||||
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
|
||||
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
|
||||
rsp_builder.add_retcode(retcode);
|
||||
rsp_builder.add_reason(rsp_reason);
|
||||
rsp_builder.add_clients(clients_fb);
|
||||
rsp_builder.add_iteration(iteration);
|
||||
rsp_builder.add_next_req_time(rsp_next_req_time);
|
||||
auto rsp_exchange_keys = rsp_builder.Finish();
|
||||
client_list_resp_builder->Finish(rsp_exchange_keys);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,9 +38,36 @@ void ExchangeKeysKernel::InitKernel(size_t) {
|
|||
cipher_key_ = &armour::CipherKeys::GetInstance();
|
||||
}
|
||||
|
||||
bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num) {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
std::string reason = "Current amount for exchangeKey is enough. Please retry later.";
|
||||
cipher_key_->BuildExchangeKeysRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
|
||||
MS_LOG(WARNING) << reason;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestExchangeKeys *exchange_keys_req,
|
||||
const int iter_num) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(exchange_keys_req, false);
|
||||
if (!DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str())) {
|
||||
std::string reason = "Counting for exchange kernel request failed. Please retry later.";
|
||||
cipher_key_->BuildExchangeKeysRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
MS_LOG(INFO) << "Launching ExchangeKey kernel.";
|
||||
bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
|
@ -48,46 +75,49 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
<< total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 input,but got " << inputs.size();
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel input num not match",
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
void *req_data = inputs[0]->addr;
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ReachThresholdForExchangeKeys(fbb, iter_num)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
|
||||
int32_t iter_client = (size_t)exchange_keys_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "ExchangeKeys iteration number is invalid: server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 output,but got " << outputs.size();
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel output num not match",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for ExchangeKeysKernel is enough.";
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||
"Current amount for ExchangeKeysKernel is enough.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::RequestExchangeKeys *exchange_keys_req =
|
||||
flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
|
||||
int32_t iter_client = (size_t)exchange_keys_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "ExchangeKeysKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
response =
|
||||
cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
|
||||
if (!response) {
|
||||
MS_LOG(WARNING) << "update exchange keys is failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!CountForExchangeKeys(fbb, exchange_keys_req, iter_num)) {
|
||||
MS_LOG(ERROR) << "count for exchange keys failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
|
||||
if (!response) {
|
||||
MS_LOG(INFO) << "ExchangeKeysKernel response is false.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||
|
@ -41,6 +43,9 @@ class ExchangeKeysKernel : public RoundKernel {
|
|||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
armour::CipherKeys *cipher_key_;
|
||||
bool ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num);
|
||||
bool CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestExchangeKeys *exchange_keys_req,
|
||||
const int iter_num);
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -37,9 +37,23 @@ void GetKeysKernel::InitKernel(size_t) {
|
|||
cipher_key_ = &armour::CipherKeys::GetInstance();
|
||||
}
|
||||
|
||||
bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
|
||||
const int iter_num) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(get_keys_req, false);
|
||||
if (!DistributedCountService::GetInstance().Count(name_, get_keys_req->fl_id()->str())) {
|
||||
std::string reason = "Counting for getkeys kernel request failed. Please retry later.";
|
||||
cipher_key_->BuildGetKeysRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), false);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
MS_LOG(INFO) << "Launching GetKeys kernel.";
|
||||
bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
|
@ -47,44 +61,48 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
|||
<< total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "GetKeysKernel needs 1 input,but got " << inputs.size();
|
||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "GetKeysKernel needs 1 output,but got " << outputs.size();
|
||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
|
||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
|
||||
int32_t iter_client = (size_t)get_exchange_keys_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
} else {
|
||||
response =
|
||||
cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, get_exchange_keys_req->fl_id()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
|
||||
}
|
||||
|
||||
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
|
||||
int32_t iter_client = (size_t)get_exchange_keys_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
|
||||
if (!response) {
|
||||
MS_LOG(WARNING) << "get public keys is failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!CountForGetKeys(fbb, get_exchange_keys_req, iter_num)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
|
||||
if (!response) {
|
||||
MS_LOG(INFO) << "GetKeysKernel response is false.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||
|
@ -41,6 +43,8 @@ class GetKeysKernel : public RoundKernel {
|
|||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
armour::CipherKeys *cipher_key_;
|
||||
bool CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
|
||||
const int iter_num);
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -39,54 +39,72 @@ void GetSecretsKernel::InitKernel(size_t) {
|
|||
cipher_share_ = &armour::CipherShares::GetInstance();
|
||||
}
|
||||
|
||||
bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::GetShareSecrets *get_secrets_req, const int iter_num) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false);
|
||||
if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) {
|
||||
std::string reason = "Counting for get secrets kernel request failed. Please retry later.";
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), nullptr);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool response = false;
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
|
||||
<< total_duration;
|
||||
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "GetSecretsKernel needs 1 input,but got " << inputs.size();
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "GetSecretsKernel needs 1 output,but got " << outputs.size();
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
|
||||
int32_t iter_client = (size_t)get_secrets_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
|
||||
} else {
|
||||
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
|
||||
int32_t iter_client = (size_t)get_secrets_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
|
||||
}
|
||||
|
||||
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
|
||||
if (!response) {
|
||||
MS_LOG(WARNING) << "get secret shares is failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!CountForGetSecrets(fbb, get_secrets_req, iter_num)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
|
||||
if (!response) {
|
||||
MS_LOG(INFO) << "GetSecretsKernel response is false.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||
|
@ -41,6 +42,8 @@ class GetSecretsKernel : public RoundKernel {
|
|||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
armour::CipherShares *cipher_share_;
|
||||
bool CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::GetShareSecrets *get_secrets_req,
|
||||
const int iter_num);
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -52,7 +52,6 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
|
|||
|
||||
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
// MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||
|
@ -61,71 +60,58 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
|||
<< total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
||||
"ReconstructSecretsKernel input num not match.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 output, but got " << outputs.size();
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
||||
"ReconstructSecretsKernel output num not match.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
// get client list from memory server.
|
||||
std::vector<string> update_model_clients;
|
||||
const PBMetadata update_model_clients_pb_out =
|
||||
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
|
||||
|
||||
for (int i = 0; i < update_model_clients_pb.fl_id_size(); ++i) {
|
||||
update_model_clients.push_back(update_model_clients_pb.fl_id(i));
|
||||
}
|
||||
|
||||
const schema::SendReconstructSecret *reconstruct_secret_req =
|
||||
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
|
||||
std::string fl_id = reconstruct_secret_req->fl_id()->str();
|
||||
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
|
||||
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
|
||||
// client in get update model client list.
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED,
|
||||
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
} else {
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::SendReconstructSecret *reconstruct_secret_req =
|
||||
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
|
||||
|
||||
// get client list from memory server.
|
||||
std::vector<string> client_list;
|
||||
uint64_t update_model_client_num = 0;
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
|
||||
update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "update_model_client_num is zero.";
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
||||
"update_model_client_num is zero.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const PBMetadata client_list_pb_out =
|
||||
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||
|
||||
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
|
||||
int client_list_actual_size = client_list_pb.fl_id_size();
|
||||
if (client_list_actual_size < 0) {
|
||||
client_list_actual_size = 0;
|
||||
}
|
||||
if (static_cast<uint64_t>(client_list_actual_size) < update_model_client_num) {
|
||||
MS_LOG(INFO) << "ReconstructSecretsKernel : client list is not ready " << inputs.size();
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SucNotReady,
|
||||
"ReconstructSecretsKernel : client list is not ready", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
||||
client_list.push_back(client_list_pb.fl_id(i));
|
||||
}
|
||||
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
|
||||
reconstruct_secret_req, fbb, client_list);
|
||||
if (response) {
|
||||
// MS_LOG(INFO) << "start ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
|
||||
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
|
||||
// MS_LOG(INFO) << "end ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
|
||||
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
|
||||
reconstruct_secret_req, fbb, update_model_clients);
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
|
||||
}
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough.";
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
|
@ -142,7 +128,6 @@ void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::
|
|||
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "start unmask";
|
||||
while (!Executor::GetInstance().Unmask()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||
|
|
|
@ -36,58 +36,75 @@ void ShareSecretsKernel::InitKernel(size_t) {
|
|||
cipher_share_ = &armour::CipherShares::GetInstance();
|
||||
}
|
||||
|
||||
bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb,
|
||||
const schema::RequestShareSecrets *share_secrets_req,
|
||||
const int iter_num) {
|
||||
MS_ERROR_IF_NULL_W_RET_VAL(share_secrets_req, false);
|
||||
if (!DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str())) {
|
||||
std::string reason = "Counting for share secret kernel request failed. Please retry later.";
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool response = false;
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is "
|
||||
<< total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 input,but got " << inputs.size();
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError, "ShareSecretsKernel input num not match",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 output,but got " << outputs.size();
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
||||
"ShareSecretsKernel output num not match",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||
"Current amount for ShareSecretsKernel is enough.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::RequestShareSecrets *share_secrets_req =
|
||||
flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
|
||||
size_t iter_client = (size_t)share_secrets_req->iteration();
|
||||
if (iter_num != iter_client) {
|
||||
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
response =
|
||||
cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||
std::string reason = "inputs or outputs size is invalid.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
void *req_data = inputs[0]->addr;
|
||||
if (fbb == nullptr || req_data == nullptr) {
|
||||
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||
MS_LOG(ERROR) << reason;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||
"Current amount for ShareSecretsKernel is enough.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
|
||||
size_t iter_client = (size_t)share_secrets_req->iteration();
|
||||
if (iter_num != iter_client) {
|
||||
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
response = cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
if (!response) {
|
||||
MS_LOG(WARNING) << "update secret shares is failed.";
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
if (!CountForShareSecrets(fbb, share_secrets_req, iter_num)) {
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return true;
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
|
||||
if (!response) {
|
||||
MS_LOG(INFO) << "share_secrets_kernel response is false.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "fl/server/common.h"
|
||||
#include "fl/server/executor.h"
|
||||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
|
@ -41,6 +43,8 @@ class ShareSecretsKernel : public RoundKernel {
|
|||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
armour::CipherShares *cipher_share_;
|
||||
bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req,
|
||||
const int iter_num);
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
|
|
|
@ -157,14 +157,31 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
|
|||
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
|
||||
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
|
||||
std::string update_model_fl_id = update_model_req->fl_id()->str();
|
||||
MS_LOG(INFO) << "Update model for fl id " << update_model_fl_id;
|
||||
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
|
||||
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id;
|
||||
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
|
||||
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
|
||||
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
} else {
|
||||
std::vector<std::string> get_secrets_clients;
|
||||
#ifdef ENABLE_ARMOUR
|
||||
mindspore::armour::CipherMetaStorage cipher_meta_storage;
|
||||
cipher_meta_storage.GetClientListFromServer(fl::server::kCtxGetSecretsClientList, &get_secrets_clients);
|
||||
#endif
|
||||
if (find(get_secrets_clients.begin(), get_secrets_clients.end(), update_model_fl_id) ==
|
||||
get_secrets_clients.end()) { // the client not in get_secrets_clients
|
||||
std::string reason = "fl_id: " + update_model_fl_id + " is not in get_secrets_clients. Please retry later.";
|
||||
BuildUpdateModelRsp(
|
||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||
MS_LOG(ERROR) << reason;
|
||||
return ResultCode::kSuccessAndReturn;
|
||||
}
|
||||
}
|
||||
|
||||
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
|
||||
|
|
|
@ -25,6 +25,9 @@
|
|||
#include "fl/server/kernel/round/round_kernel.h"
|
||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||
#include "fl/server/executor.h"
|
||||
#ifdef ENABLE_ARMOUR
|
||||
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
|
|
|
@ -213,53 +213,24 @@ void Server::InitIteration() {
|
|||
#ifdef ENABLE_ARMOUR
|
||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||
if (encrypt_type == ps::kPWEncryptType) {
|
||||
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
|
||||
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
|
||||
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
|
||||
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
|
||||
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
|
||||
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold;
|
||||
cipher_exchange_keys_cnt_ = cipher_config_.exchange_keys_threshold;
|
||||
cipher_get_keys_cnt_ = cipher_config_.get_keys_threshold;
|
||||
cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
|
||||
cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
|
||||
cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
|
||||
cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
|
||||
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
|
||||
cipher_time_window_ = cipher_config_.cipher_time_window;
|
||||
|
||||
MS_LOG(INFO) << "Initializing cipher:";
|
||||
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
|
||||
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
|
||||
MS_LOG(INFO) << " cipher_exchange_keys_cnt_: " << cipher_exchange_keys_cnt_
|
||||
<< " cipher_get_keys_cnt_: " << cipher_get_keys_cnt_
|
||||
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
|
||||
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
|
||||
MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
|
||||
<< " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
|
||||
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
|
||||
<< " cipher_time_window_: " << cipher_time_window_
|
||||
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
|
||||
|
||||
std::shared_ptr<Round> exchange_keys_round =
|
||||
std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(exchange_keys_round);
|
||||
iteration_->AddRound(exchange_keys_round);
|
||||
|
||||
std::shared_ptr<Round> get_keys_round =
|
||||
std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(get_keys_round);
|
||||
iteration_->AddRound(get_keys_round);
|
||||
|
||||
std::shared_ptr<Round> share_secrets_round =
|
||||
std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(share_secrets_round);
|
||||
iteration_->AddRound(share_secrets_round);
|
||||
|
||||
std::shared_ptr<Round> get_secrets_round =
|
||||
std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(get_secrets_round);
|
||||
iteration_->AddRound(get_secrets_round);
|
||||
|
||||
std::shared_ptr<Round> get_clientlist_round =
|
||||
std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(get_clientlist_round);
|
||||
iteration_->AddRound(get_clientlist_round);
|
||||
|
||||
std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
|
||||
"reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
|
||||
MS_EXCEPTION_IF_NULL(reconstruct_secrets_round);
|
||||
iteration_->AddRound(reconstruct_secrets_round);
|
||||
MS_LOG(INFO) << "Cipher rounds has been added.";
|
||||
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
|
||||
<< " cipher_time_window_: " << cipher_time_window_;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -314,8 +285,8 @@ void Server::InitCipher() {
|
|||
param.dp_eps = dp_eps;
|
||||
param.dp_norm_clip = dp_norm_clip;
|
||||
param.encrypt_type = encrypt_type;
|
||||
cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_,
|
||||
cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
|
||||
cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
|
||||
cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
|
||||
cipher_reconstruct_secrets_up_cnt_);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ class Server {
|
|||
worker_num_(0),
|
||||
fl_server_port_(0),
|
||||
cipher_initial_client_cnt_(0),
|
||||
cipher_exchange_secrets_cnt_(0),
|
||||
cipher_exchange_keys_cnt_(0),
|
||||
cipher_share_secrets_cnt_(0),
|
||||
cipher_get_clientlist_cnt_(0),
|
||||
cipher_reconstruct_secrets_up_cnt_(0),
|
||||
|
@ -197,8 +197,10 @@ class Server {
|
|||
uint32_t worker_num_;
|
||||
uint16_t fl_server_port_;
|
||||
size_t cipher_initial_client_cnt_;
|
||||
size_t cipher_exchange_secrets_cnt_;
|
||||
size_t cipher_exchange_keys_cnt_;
|
||||
size_t cipher_get_keys_cnt_;
|
||||
size_t cipher_share_secrets_cnt_;
|
||||
size_t cipher_get_secrets_cnt_;
|
||||
size_t cipher_get_clientlist_cnt_;
|
||||
size_t cipher_reconstruct_secrets_up_cnt_;
|
||||
size_t cipher_reconstruct_secrets_down_cnt_;
|
||||
|
|
|
@ -856,9 +856,33 @@ bool StartServerAction(const ResourcePtr &res) {
|
|||
|
||||
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
|
||||
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
|
||||
size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold();
|
||||
size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold() + 1;
|
||||
|
||||
fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshold};
|
||||
size_t exchange_keys_threshold =
|
||||
std::max(static_cast<size_t>(std::ceil(start_fl_job_threshold * share_secrets_ratio)), update_model_threshold);
|
||||
size_t get_keys_threshold =
|
||||
std::max(static_cast<size_t>(std::ceil(exchange_keys_threshold * share_secrets_ratio)), update_model_threshold);
|
||||
size_t share_secrets_threshold =
|
||||
std::max(static_cast<size_t>(std::ceil(get_keys_threshold * share_secrets_ratio)), update_model_threshold);
|
||||
size_t get_secrets_threshold =
|
||||
std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold);
|
||||
size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)),
|
||||
reconstruct_secrets_threshold);
|
||||
#ifdef ENABLE_ARMOUR
|
||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||
if (encrypt_type == ps::kPWEncryptType) {
|
||||
MS_LOG(INFO) << "Add secure aggregation rounds.";
|
||||
rounds_config.push_back({"exchangeKeys", true, cipher_time_window, true, exchange_keys_threshold});
|
||||
rounds_config.push_back({"getKeys", true, cipher_time_window, true, get_keys_threshold});
|
||||
rounds_config.push_back({"shareSecrets", true, cipher_time_window, true, share_secrets_threshold});
|
||||
rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold});
|
||||
rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
|
||||
rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
|
||||
}
|
||||
#endif
|
||||
fl::server::CipherConfig cipher_config = {
|
||||
share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,
|
||||
share_secrets_threshold, get_secrets_threshold, client_list_threshold, reconstruct_secrets_threshold};
|
||||
|
||||
size_t executor_threshold = 0;
|
||||
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
||||
|
|
|
@ -126,6 +126,13 @@ message SharesPb {
|
|||
|
||||
message KeysPb {
|
||||
repeated bytes key = 1;
|
||||
string timestamp = 2;
|
||||
int32 iter_num = 3;
|
||||
bytes ind_iv = 4;
|
||||
bytes pw_iv = 5;
|
||||
bytes pw_salt = 6;
|
||||
bytes signature = 7;
|
||||
repeated string certificate_chain = 8;
|
||||
}
|
||||
|
||||
message Prime {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,12 +16,18 @@
|
|||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.I_VEC_LEN;
|
||||
import static com.mindspore.flclient.LocalFLParameter.SALT_SIZE;
|
||||
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.cipher.AESEncrypt;
|
||||
import com.mindspore.flclient.cipher.BaseUtil;
|
||||
import com.mindspore.flclient.cipher.ClientListReq;
|
||||
import com.mindspore.flclient.cipher.KEYAgreement;
|
||||
import com.mindspore.flclient.cipher.Random;
|
||||
import com.mindspore.flclient.cipher.Masking;
|
||||
import com.mindspore.flclient.cipher.ReconstructSecretReq;
|
||||
import com.mindspore.flclient.cipher.ShareSecrets;
|
||||
import com.mindspore.flclient.cipher.struct.ClientPublicKey;
|
||||
|
@ -29,6 +35,7 @@ import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
|||
import com.mindspore.flclient.cipher.struct.EncryptShare;
|
||||
import com.mindspore.flclient.cipher.struct.NewArray;
|
||||
import com.mindspore.flclient.cipher.struct.ShareSecret;
|
||||
|
||||
import mindspore.schema.ClientShare;
|
||||
import mindspore.schema.GetExchangeKeys;
|
||||
import mindspore.schema.GetShareSecrets;
|
||||
|
@ -40,22 +47,22 @@ import mindspore.schema.ResponseShareSecrets;
|
|||
import mindspore.schema.ReturnExchangeKeys;
|
||||
import mindspore.schema.ReturnShareSecrets;
|
||||
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.io.IOException;
|
||||
import java.math.BigInteger;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.spec.InvalidKeySpecException;
|
||||
import java.time.LocalDateTime;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
|
||||
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE;
|
||||
|
||||
/**
|
||||
* A class used for secure aggregation
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class CipherClient {
|
||||
private static final Logger LOGGER = Logger.getLogger(CipherClient.class.toString());
|
||||
private FLCommunication flCommunication;
|
||||
|
@ -63,129 +70,217 @@ public class CipherClient {
|
|||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private final int iteration;
|
||||
private int featureSize;
|
||||
private int t;
|
||||
private int minShareNum;
|
||||
private List<byte[]> cKey = new ArrayList<>();
|
||||
private List<byte[]> sKey = new ArrayList<>();
|
||||
private byte[] bu;
|
||||
private byte[] individualIv = new byte[I_VEC_LEN];
|
||||
private byte[] pwIVec = new byte[I_VEC_LEN];
|
||||
private byte[] pwSalt = new byte[SALT_SIZE];
|
||||
private String nextRequestTime;
|
||||
private Map<String, ClientPublicKey> clientPublicKeyList = new HashMap<String, ClientPublicKey>();
|
||||
private Map<String, byte[]> sUVKeys = new HashMap<String, byte[]>();
|
||||
private Map<String, byte[]> cUVKeys = new HashMap<String, byte[]>();
|
||||
private List<EncryptShare> clientShareList = new ArrayList<>();
|
||||
private List<EncryptShare> returnShareList = new ArrayList<>();
|
||||
private float[] featureMask;
|
||||
private List<String> u1ClientList = new ArrayList<>();
|
||||
private List<String> u2UClientList = new ArrayList<>();
|
||||
private List<String> u3ClientList = new ArrayList<>();
|
||||
private List<DecryptShareSecrets> decryptShareSecretsList = new ArrayList<>();
|
||||
private byte[] prime;
|
||||
private KEYAgreement keyAgreement = new KEYAgreement();
|
||||
private Random random = new Random();
|
||||
private Masking masking = new Masking();
|
||||
private ClientListReq clientListReq = new ClientListReq();
|
||||
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
|
||||
private int retCode;
|
||||
|
||||
/**
|
||||
* Construct function of cipherClient
|
||||
*
|
||||
* @param iter iteration number
|
||||
* @param minSecretNum minimum secret shares number used for reconstruct secret
|
||||
* @param prime prime value
|
||||
* @param featureSize featureSize of network
|
||||
*/
|
||||
public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
this.iteration = iter;
|
||||
this.featureSize = featureSize;
|
||||
this.t = minSecretNum;
|
||||
this.minShareNum = minSecretNum;
|
||||
this.prime = prime;
|
||||
this.featureMask = new float[this.featureSize];
|
||||
}
|
||||
|
||||
/**
|
||||
* Set next request time
|
||||
*
|
||||
* @param nextRequestTime next request timestamp
|
||||
*/
|
||||
public void setNextRequestTime(String nextRequestTime) {
|
||||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
public void setBU(byte[] bu) {
|
||||
this.bu = bu;
|
||||
}
|
||||
|
||||
public void setClientShareList(List<EncryptShare> clientShareList) {
|
||||
/**
|
||||
* Set client share list
|
||||
*
|
||||
* @param clientShareList client share list
|
||||
*/
|
||||
private void setClientShareList(List<EncryptShare> clientShareList) {
|
||||
this.clientShareList.clear();
|
||||
this.clientShareList = clientShareList;
|
||||
}
|
||||
|
||||
/**
|
||||
* get next request time
|
||||
*
|
||||
* @return next request time
|
||||
*/
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* get retCode
|
||||
*
|
||||
* @return retCode
|
||||
*/
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
public void genDHKeyPairs() {
|
||||
private FLClientStatus genDHKeyPairs() {
|
||||
byte[] csk = keyAgreement.generatePrivateKey();
|
||||
byte[] cpk = keyAgreement.generatePublicKey(csk);
|
||||
if (cpk == null || cpk.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[genDHKeyPairs] the return byte[] <cpk> is null, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
byte[] ssk = keyAgreement.generatePrivateKey();
|
||||
byte[] spk = keyAgreement.generatePublicKey(ssk);
|
||||
if (spk == null || spk.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[genDHKeyPairs] the return byte[] <spk> is null, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
this.cKey.clear();
|
||||
this.sKey.clear();
|
||||
this.cKey.add(cpk);
|
||||
this.cKey.add(csk);
|
||||
this.sKey.add(spk);
|
||||
this.sKey.add(ssk);
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
public void genIndividualSecret() {
|
||||
private FLClientStatus genIndividualSecret() {
|
||||
byte[] key = new byte[SEED_SIZE];
|
||||
random.getRandomBytes(key);
|
||||
setBU(key);
|
||||
int tag = masking.getRandomBytes(key);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[genIndividualSecret] the return value is -1, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
this.bu = key;
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
public List<ShareSecret> genSecretShares(byte[] secret) throws UnsupportedEncodingException {
|
||||
List<ShareSecret> shareSecretList = new ArrayList<>();
|
||||
private List<ShareSecret> genSecretShares(byte[] secret) {
|
||||
if (secret == null || secret.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[genSecretShares] the input argument <secret> is null"));
|
||||
return new ArrayList<>();
|
||||
}
|
||||
int size = u1ClientList.size();
|
||||
ShareSecrets shamir = new ShareSecrets(t, size - 1);
|
||||
ShareSecrets.SecretShare[] shares = shamir.split(secret, prime);
|
||||
int j = 0;
|
||||
for (int i = 0; i < size; i++) {
|
||||
String vFlID = u1ClientList.get(i);
|
||||
if (size <= 1) {
|
||||
LOGGER.severe(Common.addTag("[genSecretShares] the size of u1ClientList is not valid: <= 1, it should be " +
|
||||
"> 1"));
|
||||
return new ArrayList<>();
|
||||
}
|
||||
ShareSecrets shamir = new ShareSecrets(minShareNum, size - 1);
|
||||
ShareSecrets.SecretShares[] shares = shamir.split(secret, prime);
|
||||
if (shares == null || shares.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[genSecretShares] the return ShareSecrets.SecretShare[] is null, please " +
|
||||
"check!"));
|
||||
return new ArrayList<>();
|
||||
}
|
||||
int shareIndex = 0;
|
||||
List<ShareSecret> shareSecretList = new ArrayList<>();
|
||||
for (String vFlID : u1ClientList) {
|
||||
if (localFLParameter.getFlID().equals(vFlID)) {
|
||||
continue;
|
||||
} else {
|
||||
ShareSecret shareSecret = new ShareSecret();
|
||||
NewArray<byte[]> array = new NewArray<>();
|
||||
int index = shares[j].getNum();
|
||||
BigInteger intShare = shares[j].getShare();
|
||||
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
|
||||
array.setSize(share.length);
|
||||
array.setArray(share);
|
||||
shareSecret.setFlID(vFlID);
|
||||
shareSecret.setShare(array);
|
||||
shareSecret.setIndex(index);
|
||||
shareSecretList.add(shareSecret);
|
||||
j += 1;
|
||||
}
|
||||
if (shareIndex >= shares.length) {
|
||||
LOGGER.severe(Common.addTag("[genSecretShares] the shareIndex is out of range in array <shares>, " +
|
||||
"please check!"));
|
||||
return new ArrayList<>();
|
||||
}
|
||||
int index = shares[shareIndex].getNumber();
|
||||
BigInteger intShare = shares[shareIndex].getShares();
|
||||
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
|
||||
NewArray<byte[]> array = new NewArray<>();
|
||||
array.setSize(share.length);
|
||||
array.setArray(share);
|
||||
ShareSecret shareSecret = new ShareSecret();
|
||||
shareSecret.setFlID(vFlID);
|
||||
shareSecret.setShare(array);
|
||||
shareSecret.setIndex(index);
|
||||
shareSecretList.add(shareSecret);
|
||||
shareIndex += 1;
|
||||
}
|
||||
return shareSecretList;
|
||||
}
|
||||
|
||||
public void genEncryptExchangedKeys() throws InvalidKeySpecException, NoSuchAlgorithmException {
|
||||
private FLClientStatus genEncryptExchangedKeys() {
|
||||
cUVKeys.clear();
|
||||
for (String key : clientPublicKeyList.keySet()) {
|
||||
ClientPublicKey curPublicKey = clientPublicKeyList.get(key);
|
||||
String vFlID = curPublicKey.getFlID();
|
||||
if (localFLParameter.getFlID().equals(vFlID)) {
|
||||
continue;
|
||||
} else {
|
||||
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
|
||||
byte[] salt = new byte[0];
|
||||
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
||||
cUVKeys.put(vFlID, secret);
|
||||
}
|
||||
if (cKey.size() < 2) {
|
||||
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the size of cKey is not valid: < 2, it should " +
|
||||
"be >= 2, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
|
||||
if (secret1 == null || secret1.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the returned secret1 is null, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
byte[] salt = new byte[0];
|
||||
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
||||
if (secret == null || secret.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the returned secret is null, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
cUVKeys.put(vFlID, secret);
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
public void encryptShares() throws Exception {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for RequestShareSecrets **************"));
|
||||
List<EncryptShare> encryptShareList = new ArrayList<>();
|
||||
private FLClientStatus encryptShares() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for " +
|
||||
"RequestShareSecrets **************"));
|
||||
// connect sSkUv, bUV, sIndex, indexB and then Encrypt them
|
||||
if (sKey.size() < 2) {
|
||||
LOGGER.severe(Common.addTag("[encryptShares] the size of sKey is not valid: < 2, it should be >= 2, " +
|
||||
"please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
List<ShareSecret> sSkUv = genSecretShares(sKey.get(1));
|
||||
if (sSkUv.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[encryptShares] the returned List<ShareSecret> sSkUv is empty, please " +
|
||||
"check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
List<ShareSecret> bUV = genSecretShares(bu);
|
||||
if (sSkUv.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[encryptShares] the returned List<ShareSecret> bUV is empty, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (sSkUv.size() != bUV.size()) {
|
||||
LOGGER.severe(Common.addTag("[encryptShares] the sSkUv.size() should be equal to bUV.size(), please " +
|
||||
"check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
List<EncryptShare> encryptShareList = new ArrayList<>();
|
||||
for (int i = 0; i < bUV.size(); i++) {
|
||||
EncryptShare encryptShare = new EncryptShare();
|
||||
NewArray<byte[]> array = new NewArray<>();
|
||||
String vFlID = bUV.get(i).getFlID();
|
||||
byte[] sShare = sSkUv.get(i).getShare().getArray();
|
||||
byte[] bShare = bUV.get(i).getShare().getArray();
|
||||
byte[] sIndex = BaseUtil.integer2byteArray(sSkUv.get(i).getIndex());
|
||||
|
@ -200,53 +295,101 @@ public class CipherClient {
|
|||
System.arraycopy(sShare, 0, allSecret, 4 + sIndex.length + bIndex.length, sShare.length);
|
||||
System.arraycopy(bShare, 0, allSecret, 4 + sIndex.length + bIndex.length + sShare.length, bShare.length);
|
||||
// encrypt:
|
||||
byte[] iVecIn = new byte[IVEC_LEN];
|
||||
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), iVecIn, "CBC");
|
||||
String vFlID = bUV.get(i).getFlID();
|
||||
if (!cUVKeys.containsKey(vFlID)) {
|
||||
LOGGER.severe(Common.addTag("[encryptShares] the key " + vFlID + " is not in map cUVKeys, please " +
|
||||
"check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), "CBC");
|
||||
byte[] encryptData = aesEncrypt.encrypt(cUVKeys.get(vFlID), allSecret);
|
||||
if (encryptData == null || encryptData.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[encryptShares] the return byte[] is null, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
NewArray<byte[]> array = new NewArray<>();
|
||||
array.setSize(encryptData.length);
|
||||
array.setArray(encryptData);
|
||||
EncryptShare encryptShare = new EncryptShare();
|
||||
encryptShare.setFlID(vFlID);
|
||||
encryptShare.setShare(array);
|
||||
encryptShareList.add(encryptShare);
|
||||
}
|
||||
setClientShareList(encryptShareList);
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
public float[] doubleMaskingWeight() throws Exception {
|
||||
int size = u2UClientList.size();
|
||||
/**
|
||||
* get masked weight of secure aggregation
|
||||
*
|
||||
* @return masked weight
|
||||
*/
|
||||
public float[] doubleMaskingWeight() {
|
||||
List<Float> noiseBu = new ArrayList<>();
|
||||
random.randomAESCTR(noiseBu, featureSize, bu);
|
||||
int tag = masking.getMasking(noiseBu, featureSize, bu, individualIv);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the return value is -1, please check!"));
|
||||
return new float[0];
|
||||
}
|
||||
float[] mask = new float[featureSize];
|
||||
for (int i = 0; i < size; i++) {
|
||||
String vFlID = u2UClientList.get(i);
|
||||
for (String vFlID : u2UClientList) {
|
||||
if (!clientPublicKeyList.containsKey(vFlID)) {
|
||||
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the key " + vFlID + " is not in map " +
|
||||
"clientPublicKeyList, please check!"));
|
||||
return new float[0];
|
||||
}
|
||||
ClientPublicKey curPublicKey = clientPublicKeyList.get(vFlID);
|
||||
if (localFLParameter.getFlID().equals(vFlID)) {
|
||||
continue;
|
||||
}
|
||||
byte[] salt;
|
||||
byte[] iVec;
|
||||
if (vFlID.compareTo(localFLParameter.getFlID()) < 0) {
|
||||
salt = curPublicKey.getPwSalt().getArray();
|
||||
iVec = curPublicKey.getPwIv().getArray();
|
||||
} else {
|
||||
byte[] salt = new byte[0];
|
||||
byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray());
|
||||
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
||||
sUVKeys.put(vFlID, secret);
|
||||
List<Float> noiseSuv = new ArrayList<>();
|
||||
random.randomAESCTR(noiseSuv, featureSize, secret);
|
||||
int sign;
|
||||
if (localFLParameter.getFlID().compareTo(vFlID) > 0) {
|
||||
sign = 1;
|
||||
} else {
|
||||
sign = -1;
|
||||
}
|
||||
for (int j = 0; j < noiseSuv.size(); j++) {
|
||||
mask[j] = mask[j] + sign * noiseSuv.get(j);
|
||||
}
|
||||
salt = this.pwSalt;
|
||||
iVec = this.pwIVec;
|
||||
}
|
||||
if (sKey.size() < 2) {
|
||||
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the size of sKey is not valid: < 2, it should be " +
|
||||
">= 2, please check!"));
|
||||
return new float[0];
|
||||
}
|
||||
byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray());
|
||||
if (secret1 == null || secret1.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the returned secret1 is null, please check!"));
|
||||
return new float[0];
|
||||
}
|
||||
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
||||
if (secret == null || secret.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the returned secret is null, please check!"));
|
||||
return new float[0];
|
||||
}
|
||||
sUVKeys.put(vFlID, secret);
|
||||
List<Float> noiseSuv = new ArrayList<>();
|
||||
tag = masking.getMasking(noiseSuv, featureSize, secret, iVec);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the return value is -1, please check!"));
|
||||
return new float[0];
|
||||
}
|
||||
int sign;
|
||||
if (localFLParameter.getFlID().compareTo(vFlID) > 0) {
|
||||
sign = 1;
|
||||
} else {
|
||||
sign = -1;
|
||||
}
|
||||
for (int maskIndex = 0; maskIndex < noiseSuv.size(); maskIndex++) {
|
||||
mask[maskIndex] = mask[maskIndex] + sign * noiseSuv.get(maskIndex);
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < noiseBu.size(); j++) {
|
||||
mask[j] = mask[j] + noiseBu.get(j);
|
||||
for (int maskIndex = 0; maskIndex < noiseBu.size(); maskIndex++) {
|
||||
mask[maskIndex] = mask[maskIndex] + noiseBu.get(maskIndex);
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
public NewArray<byte[]> byteToArray(ByteBuffer buf, int size) {
|
||||
private NewArray<byte[]> byteToArray(ByteBuffer buf, int size) {
|
||||
NewArray<byte[]> newArray = new NewArray<>();
|
||||
newArray.setSize(size);
|
||||
byte[] array = new byte[size];
|
||||
|
@ -258,40 +401,80 @@ public class CipherClient {
|
|||
return newArray;
|
||||
}
|
||||
|
||||
public FLClientStatus requestExchangeKeys() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "=============="));
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
genDHKeyPairs();
|
||||
private FLClientStatus requestExchangeKeys() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() +
|
||||
"=============="));
|
||||
FLClientStatus status = genDHKeyPairs();
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
LOGGER.severe(Common.addTag("[requestExchangeKeys] the return status is FAILED, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (cKey.size() <= 0 || sKey.size() <= 0) {
|
||||
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of cKey or sKey is not valid: <=0."));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (cKey.size() < 2) {
|
||||
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of cKey is not valid: < 2, it should be >= 2," +
|
||||
" please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (sKey.size() < 2) {
|
||||
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of sKey is not valid: < 2, it should be >= 2," +
|
||||
" please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
byte[] cPK = cKey.get(0);
|
||||
byte[] sPK = sKey.get(0);
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
|
||||
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
|
||||
byte[] indIv = new byte[I_VEC_LEN];
|
||||
byte[] pwIv = new byte[I_VEC_LEN];
|
||||
byte[] thisPwSalt = new byte[SALT_SIZE];
|
||||
SecureRandom secureRandom = Common.getSecureRandom();
|
||||
secureRandom.nextBytes(indIv);
|
||||
secureRandom.nextBytes(pwIv);
|
||||
secureRandom.nextBytes(thisPwSalt);
|
||||
this.individualIv = indIv;
|
||||
this.pwIVec = pwIv;
|
||||
this.pwSalt = thisPwSalt;
|
||||
int indIvFbs = RequestExchangeKeys.createIndIvVector(fbBuilder, indIv);
|
||||
int pwIvFbs = RequestExchangeKeys.createPwIvVector(fbBuilder, pwIv);
|
||||
int pwSaltFbs = RequestExchangeKeys.createPwSaltVector(fbBuilder, thisPwSalt);
|
||||
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time);
|
||||
|
||||
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time
|
||||
, indIvFbs, pwIvFbs, pwSaltFbs);
|
||||
fbBuilder.finish(exchangeKeysRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg);
|
||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[requestExchangeKeys] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(responseData)) {
|
||||
LOGGER.info(Common.addTag("[requestExchangeKeys] the server is not ready now, need wait some time and" +
|
||||
" " +
|
||||
"request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer);
|
||||
FLClientStatus status = judgeRequestExchangeKeys(responseExchangeKeys);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return judgeRequestExchangeKeys(responseExchangeKeys);
|
||||
} catch (IOException ex) {
|
||||
LOGGER.severe(Common.addTag("[requestExchangeKeys] catch IOException: " + ex.getMessage()));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
|
||||
private FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
|
@ -303,7 +486,8 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request startFLJob again"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request " +
|
||||
"startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
|
@ -311,39 +495,43 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys is invalid: " + retCode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys " +
|
||||
"is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus getExchangeKeys() {
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
private FLClientStatus getExchangeKeys() {
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
int getExchangeKeysRoot = GetExchangeKeys.createGetExchangeKeys(fbBuilder, id, iteration, time);
|
||||
fbBuilder.finish(getExchangeKeysRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg);
|
||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[getExchangeKeys] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(responseData)) {
|
||||
LOGGER.info(Common.addTag("[getExchangeKeys] the server is not ready now, need wait some time and " +
|
||||
"request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer);
|
||||
FLClientStatus status = judgeGetExchangeKeys(returnExchangeKeys);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return judgeGetExchangeKeys(returnExchangeKeys);
|
||||
} catch (IOException ex) {
|
||||
LOGGER.severe(Common.addTag("[getExchangeKeys] catch IOException: " + ex.getMessage()));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
|
||||
private FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
|
@ -363,17 +551,25 @@ public class CipherClient {
|
|||
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
|
||||
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
|
||||
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
|
||||
ByteBuffer bufPwIv = bufData.remotePublickeys(i).pwIvAsByteBuffer();
|
||||
int sizePwIv = bufData.remotePublickeys(i).pwIvLength();
|
||||
ByteBuffer bufPwSalt = bufData.remotePublickeys(i).pwSaltAsByteBuffer();
|
||||
int sizePwSalt = bufData.remotePublickeys(i).pwSaltLength();
|
||||
publicKey.setCPK(byteToArray(bufCpk, sizeCpk));
|
||||
publicKey.setSPK(byteToArray(bufSpk, sizeSpk));
|
||||
publicKey.setPwIv(byteToArray(bufPwIv, sizePwIv));
|
||||
publicKey.setPwSalt(byteToArray(bufPwSalt, sizePwSalt));
|
||||
clientPublicKeyList.put(srcFlId, publicKey);
|
||||
u1ClientList.add(srcFlId);
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.SucNotReady):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetExchangeKeys again!"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
|
||||
"GetExchangeKeys again!"));
|
||||
return FLClientStatus.WAIT;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request startFLJob again"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request " +
|
||||
"startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
|
@ -381,32 +577,48 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is invalid: " + retCode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is" +
|
||||
" invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus requestShareSecrets() throws Exception {
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
genIndividualSecret();
|
||||
genEncryptExchangedKeys();
|
||||
encryptShares();
|
||||
|
||||
private FLClientStatus requestShareSecrets() {
|
||||
FLClientStatus status = genIndividualSecret();
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from genIndividualSecret" +
|
||||
"(), please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
status = genEncryptExchangedKeys();
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from " +
|
||||
"genEncryptExchangedKeys(), please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
status = encryptShares();
|
||||
if (status == FLClientStatus.FAILED) {
|
||||
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from encryptShares(), " +
|
||||
"please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
int clientShareSize = clientShareList.size();
|
||||
if (clientShareSize <= 0) {
|
||||
LOGGER.warning(Common.addTag("[PairWiseMask] encrypt shares is not ready now!"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
FLClientStatus status = requestShareSecrets();
|
||||
return status;
|
||||
return requestShareSecrets();
|
||||
} else {
|
||||
int[] add = new int[clientShareSize];
|
||||
for (int i = 0; i < clientShareSize; i++) {
|
||||
int flID = fbBuilder.createString(clientShareList.get(i).getFlID());
|
||||
int shareSecretFbs = ClientShare.createShareVector(fbBuilder, clientShareList.get(i).getShare().getArray());
|
||||
int shareSecretFbs = ClientShare.createShareVector(fbBuilder,
|
||||
clientShareList.get(i).getShare().getArray());
|
||||
ClientShare.startClientShare(fbBuilder);
|
||||
ClientShare.addFlId(fbBuilder, flID);
|
||||
ClientShare.addShare(fbBuilder, shareSecretFbs);
|
||||
|
@ -414,29 +626,33 @@ public class CipherClient {
|
|||
add[i] = clientShareRoot;
|
||||
}
|
||||
int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add);
|
||||
int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id, encryptedSharesFbs, iteration, time);
|
||||
int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id,
|
||||
encryptedSharesFbs, iteration, time);
|
||||
fbBuilder.finish(requestShareSecretsRoot);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
|
||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[requestShareSecrets] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(responseData)) {
|
||||
LOGGER.info(Common.addTag("[requestShareSecrets] the server is not ready now, need wait some time" +
|
||||
" " +
|
||||
"and request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer);
|
||||
FLClientStatus status = judgeRequestShareSecrets(responseShareSecrets);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return judgeRequestShareSecrets(responseShareSecrets);
|
||||
} catch (IOException ex) {
|
||||
LOGGER.severe(Common.addTag("[requestShareSecrets] catch IOException: " + ex.getMessage()));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
|
||||
private FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
|
@ -448,7 +664,8 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request startFLJob again"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request " +
|
||||
"startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
|
@ -456,39 +673,43 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets is invalid: " + retCode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets " +
|
||||
"is invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus getShareSecrets() {
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
private FLClientStatus getShareSecrets() {
|
||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
int time = fbBuilder.createString(dateTime);
|
||||
int getShareSecrets = GetShareSecrets.createGetShareSecrets(fbBuilder, id, iteration, time);
|
||||
fbBuilder.finish(getShareSecrets);
|
||||
byte[] msg = fbBuilder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg);
|
||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[getShareSecrets] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(responseData)) {
|
||||
LOGGER.info(Common.addTag("[getShareSecrets] the server is not ready now, need wait some time and " +
|
||||
"request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer);
|
||||
FLClientStatus status = judgeGetShareSecrets(returnShareSecrets);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return judgeGetShareSecrets(returnShareSecrets);
|
||||
} catch (IOException ex) {
|
||||
LOGGER.severe(Common.addTag("[getShareSecrets] catch IOException: " + ex.getMessage()));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
|
||||
private FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
|
@ -503,20 +724,26 @@ public class CipherClient {
|
|||
int length = bufData.encryptedSharesLength();
|
||||
for (int i = 0; i < length; i++) {
|
||||
EncryptShare shareSecret = new EncryptShare();
|
||||
shareSecret.setFlID(bufData.encryptedShares(i).flId());
|
||||
ByteBuffer bufShare = bufData.encryptedShares(i).shareAsByteBuffer();
|
||||
int sizeShare = bufData.encryptedShares(i).shareLength();
|
||||
ClientShare clientShare = bufData.encryptedShares(i);
|
||||
if (clientShare == null) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the clientShare returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
shareSecret.setFlID(clientShare.flId());
|
||||
ByteBuffer bufShare = clientShare.shareAsByteBuffer();
|
||||
int sizeShare = clientShare.shareLength();
|
||||
shareSecret.setShare(byteToArray(bufShare, sizeShare));
|
||||
returnShareList.add(shareSecret);
|
||||
u2UClientList.add(bufData.encryptedShares(i).flId());
|
||||
u2UClientList.add(clientShare.flId());
|
||||
}
|
||||
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.SucNotReady):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetShareSecrets again!"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
|
||||
"GetShareSecrets again!"));
|
||||
return FLClientStatus.WAIT;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request startFLJob again"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request " +
|
||||
"startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
|
@ -524,15 +751,22 @@ public class CipherClient {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is invalid: " + retCode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is" +
|
||||
" invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* exchangeKeys round of secure aggregation
|
||||
*
|
||||
* @return round execution result
|
||||
*/
|
||||
public FLClientStatus exchangeKeys() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys ======================"));
|
||||
FLClientStatus curStatus;
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys " +
|
||||
"======================"));
|
||||
// RequestExchangeKeys
|
||||
FLClientStatus curStatus;
|
||||
curStatus = requestExchangeKeys();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
|
@ -551,8 +785,14 @@ public class CipherClient {
|
|||
return curStatus;
|
||||
}
|
||||
|
||||
public FLClientStatus shareSecrets() throws Exception {
|
||||
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets ======================")));
|
||||
/**
|
||||
* shareSecrets round of secure aggregation
|
||||
*
|
||||
* @return round execution result
|
||||
*/
|
||||
public FLClientStatus shareSecrets() {
|
||||
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets " +
|
||||
"======================")));
|
||||
FLClientStatus curStatus;
|
||||
// RequestShareSecrets
|
||||
curStatus = requestShareSecrets();
|
||||
|
@ -573,14 +813,22 @@ public class CipherClient {
|
|||
return curStatus;
|
||||
}
|
||||
|
||||
/**
|
||||
* reconstructSecrets round of secure aggregation
|
||||
*
|
||||
* @return round execution result
|
||||
*/
|
||||
public FLClientStatus reconstructSecrets() {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret ========================"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret " +
|
||||
"========================"));
|
||||
FLClientStatus curStatus;
|
||||
// GetClientList
|
||||
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
|
||||
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList,
|
||||
cUVKeys);
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
|
||||
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList
|
||||
, cUVKeys);
|
||||
}
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
nextRequestTime = clientListReq.getNextRequestTime();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,10 +13,17 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import org.bouncycastle.crypto.BlockCipher;
|
||||
import org.bouncycastle.crypto.engines.AESEngine;
|
||||
import org.bouncycastle.crypto.prng.SP800SecureRandomBuilder;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Date;
|
||||
|
@ -26,28 +33,94 @@ import java.util.logging.Logger;
|
|||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
/**
|
||||
* Define basic global methods used in federated learning task.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class Common {
|
||||
/**
|
||||
* Global logger title.
|
||||
*/
|
||||
public static final String LOG_TITLE = "<FLClient> ";
|
||||
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
|
||||
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "albert"));
|
||||
|
||||
public static String generateUrl(boolean useHttps, boolean useElb, String ip, int port, int serverNum) {
|
||||
if (useHttps) {
|
||||
ip = "https://" + ip + ":";
|
||||
} else {
|
||||
ip = "http://" + ip + ":";
|
||||
/**
|
||||
* The list of trust flName.
|
||||
*/
|
||||
public static final List<String> FL_NAME_TRUST_LIST = new ArrayList<>(Arrays.asList("lenet", "albert"));
|
||||
|
||||
/**
|
||||
* The list of trust ssl protocol.
|
||||
*/
|
||||
public static final List<String> SSL_PROTOCOL_TRUST_LIST = new ArrayList<>(Arrays.asList("TLSv1.3", "TLSv1.2"));
|
||||
|
||||
/**
|
||||
* The tag when server is in safe mode.
|
||||
*/
|
||||
public static final String SAFE_MOD = "The cluster is in safemode.";
|
||||
|
||||
/**
|
||||
* The tag when server is not ready.
|
||||
*/
|
||||
public static final String JOB_NOT_AVAILABLE = "The server's training job is disabled or finished.";
|
||||
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
|
||||
private static SecureRandom secureRandom;
|
||||
|
||||
/**
|
||||
* Generate the URL for device-sever interaction
|
||||
*
|
||||
* @param ifUseElb whether a client randomly sends a request to a server address within a specified range.
|
||||
* @param serverNum number of servers that can send requests.
|
||||
* @param domainName the URL for device-sever interaction set by user.
|
||||
* @return the URL for device-sever interaction.
|
||||
*/
|
||||
public static String generateUrl(boolean ifUseElb, int serverNum, String domainName) {
|
||||
if (serverNum <= 0) {
|
||||
LOGGER.severe(Common.addTag("[generateUrl] the input argument <serverNum> is not valid: <= 0, it should " +
|
||||
"be > 0, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
String url;
|
||||
if (useElb) {
|
||||
if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
|
||||
LOGGER.severe(Common.addTag("[generateUrl] the input argument <domainName> is null or not valid, it " +
|
||||
"should be like as https://...... or http://...... , please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (ifUseElb) {
|
||||
if (domainName.split("//")[1].split(":").length < 2) {
|
||||
LOGGER.severe(Common.addTag("[generateUrl] the format of <domainName> is not valid, it should be like" +
|
||||
" as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
String ip = domainName.split("//")[1].split(":")[0];
|
||||
int port = Integer.parseInt(domainName.split("//")[1].split(":")[1]);
|
||||
if (!Common.checkIP(ip)) {
|
||||
LOGGER.severe(Common.addTag("[generateUrl] the <ip> split from domainName is not valid, domainName " +
|
||||
"should be like as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, " +
|
||||
"please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (!Common.checkPort(port)) {
|
||||
LOGGER.severe(Common.addTag("[generateUrl] the <port> split from domainName is not valid, domainName " +
|
||||
"should be like as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, " +
|
||||
"please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
String tag = domainName.split("//")[0] + "//";
|
||||
Random rand = new Random();
|
||||
int randomNum = rand.nextInt(100000) % serverNum + port;
|
||||
url = ip + String.valueOf(randomNum);
|
||||
url = tag + ip + ":" + String.valueOf(randomNum);
|
||||
} else {
|
||||
url = ip + String.valueOf(port);
|
||||
url = domainName;
|
||||
}
|
||||
return url;
|
||||
}
|
||||
|
||||
/**
|
||||
* Store weight name of classifier to a list.
|
||||
*
|
||||
* @param classifierWeightName the list to store weight name of classifier.
|
||||
*/
|
||||
public static void setClassifierWeightName(List<String> classifierWeightName) {
|
||||
classifierWeightName.add("albert.pooler.weight");
|
||||
classifierWeightName.add("albert.pooler.bias");
|
||||
|
@ -56,6 +129,11 @@ public class Common {
|
|||
LOGGER.info(addTag("classifierWeightName size: " + classifierWeightName.size()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Store weight name of albert network to a list.
|
||||
*
|
||||
* @param albertWeightName the list to store weight name of albert network.
|
||||
*/
|
||||
public static void setAlbertWeightName(List<String> albertWeightName) {
|
||||
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.weight");
|
||||
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.bias");
|
||||
|
@ -78,32 +156,67 @@ public class Common {
|
|||
LOGGER.info(addTag("albertWeightName size: " + albertWeightName.size()));
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether the flName set by user is in the trust list.
|
||||
*
|
||||
* @param flName the model name set by user.
|
||||
* @return boolean value, true indicates the flName set by user is valid, false indicates the flName set by user
|
||||
* is not valid.
|
||||
*/
|
||||
public static boolean checkFLName(String flName) {
|
||||
return (flNameTrustList.contains(flName));
|
||||
return (FL_NAME_TRUST_LIST.contains(flName));
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether the sslProtocol set by user is in the trust list.
|
||||
*
|
||||
* @param sslProtocol the ssl protocol set by user.
|
||||
* @return boolean value, true indicates the sslProtocol set by user is valid, false indicates the sslProtocol
|
||||
* set by user is not valid.
|
||||
*/
|
||||
public static boolean checkSSLProtocol(String sslProtocol) {
|
||||
return (SSL_PROTOCOL_TRUST_LIST.contains(sslProtocol));
|
||||
}
|
||||
|
||||
/**
|
||||
* The program waits for the specified time and then to continue.
|
||||
*
|
||||
* @param millis the waiting time (ms).
|
||||
*/
|
||||
public static void sleep(long millis) {
|
||||
try {
|
||||
Thread.sleep(millis); //1000 milliseconds is one second.
|
||||
Thread.sleep(millis); // 1000 milliseconds is one second.
|
||||
} catch (InterruptedException ex) {
|
||||
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the waiting time for repeated requests.
|
||||
*
|
||||
* @param nextRequestTime the timestamp return from server.
|
||||
* @return the waiting time for repeated requests.
|
||||
*/
|
||||
public static long getWaitTime(String nextRequestTime) {
|
||||
|
||||
Date date = new Date();
|
||||
long currentTime = date.getTime();
|
||||
long waitTime = 0;
|
||||
if (!("").equals(nextRequestTime)) {
|
||||
long waitTime = 0L;
|
||||
if (!(nextRequestTime == null || nextRequestTime.isEmpty())) {
|
||||
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
|
||||
}
|
||||
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime));
|
||||
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " +
|
||||
currentTime));
|
||||
LOGGER.info(addTag("[getWaitTime] waitTime: " + waitTime));
|
||||
return waitTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get start time.
|
||||
*
|
||||
* @param tag the tag added to the logger.
|
||||
* @return start time.
|
||||
*/
|
||||
public static long startTime(String tag) {
|
||||
Date startDate = new Date();
|
||||
long startTime = startDate.getTime();
|
||||
|
@ -111,6 +224,12 @@ public class Common {
|
|||
return startTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get end time.
|
||||
*
|
||||
* @param start the start time.
|
||||
* @param tag the tag added to the logger.
|
||||
*/
|
||||
public static void endTime(long start, String tag) {
|
||||
Date endDate = new Date();
|
||||
long endTime = endDate.getTime();
|
||||
|
@ -118,53 +237,182 @@ public class Common {
|
|||
LOGGER.info(addTag("[interval time] <" + tag + "> interval time(ms): " + (endTime - start)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Add specified tag to the message.
|
||||
*
|
||||
* @param message the message need to add tag.
|
||||
* @return the message after adding tag.
|
||||
*/
|
||||
public static String addTag(String message) {
|
||||
return LOG_TITLE + message;
|
||||
}
|
||||
|
||||
public static boolean isSafeMod(byte[] message, String safeModTag) {
|
||||
return (new String(message)).contains(safeModTag);
|
||||
/**
|
||||
* Check whether the server is ready based on the message returned by the server.
|
||||
*
|
||||
* @param message the message returned by the server..
|
||||
* @return boolean value, true indicates the server is ready, false indicates the server is not ready.
|
||||
*/
|
||||
public static boolean isSeverReady(byte[] message) {
|
||||
if (message == null) {
|
||||
LOGGER.severe(Common.addTag("[isSeverReady] the input argument <message> is null, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
String messageStr = new String(message);
|
||||
if (messageStr.contains(SAFE_MOD)) {
|
||||
LOGGER.info(Common.addTag("[isSeverReady] " + SAFE_MOD + ", need wait some time and request again"));
|
||||
return false;
|
||||
} else if (messageStr.contains(JOB_NOT_AVAILABLE)) {
|
||||
LOGGER.info(Common.addTag("[isSeverReady] " + JOB_NOT_AVAILABLE + ", need wait some time and request " +
|
||||
"again"));
|
||||
return false;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
public static String getRealPath (String path) {
|
||||
LOGGER.info(addTag("[original path] " + path));
|
||||
/**
|
||||
* Convert a user-set path to a standard path.
|
||||
*
|
||||
* @param path the user-set path.
|
||||
* @return the standard path.
|
||||
*/
|
||||
public static String getRealPath(String path) {
|
||||
if (path == null) {
|
||||
LOGGER.severe(Common.addTag("[getRealPath] the input argument <path> is null, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
LOGGER.info(addTag("[getRealPath] original path: " + path));
|
||||
String[] paths = path.split(",");
|
||||
for (int i = 0; i < paths.length; i++) {
|
||||
LOGGER.info(addTag("[original path " + i + "] " + paths[i]));
|
||||
if (paths[i] == null) {
|
||||
LOGGER.severe(Common.addTag("[getRealPath] the paths[i] is null, please check"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
LOGGER.info(addTag("[getRealPath] original path " + i + ": " + paths[i]));
|
||||
File file = new File(paths[i]);
|
||||
try {
|
||||
paths[i] = file.getCanonicalPath();
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(addTag("[checkPath] catch IOException in file.getCanonicalPath(): " + e.getMessage()));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(addTag("[getRealPath] catch IOException in file.getCanonicalPath(): " + e.getMessage()));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
path = String.join(",", Arrays.asList(paths));
|
||||
LOGGER.info(addTag("[real path] " + path));
|
||||
return path;
|
||||
String realPath = String.join(",", Arrays.asList(paths));
|
||||
LOGGER.info(addTag("[getRealPath] real path: " + realPath));
|
||||
return realPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether the path set by user exists.
|
||||
*
|
||||
* @param path the path set by user.
|
||||
* @return boolean value, true indicates the path is exist, false indicates the path is not exist
|
||||
*/
|
||||
public static boolean checkPath(String path) {
|
||||
boolean tag = true;
|
||||
if (path == null) {
|
||||
LOGGER.severe(Common.addTag("[checkPath] the input argument <path> is null, please check!"));
|
||||
return false;
|
||||
}
|
||||
String[] paths = path.split(",");
|
||||
for (int i = 0; i < paths.length; i++) {
|
||||
if (paths[i] == null) {
|
||||
LOGGER.severe(Common.addTag("[checkPath] the paths[i] is null, please check"));
|
||||
return false;
|
||||
}
|
||||
LOGGER.info(addTag("[check path " + i + "] " + paths[i]));
|
||||
File file = new File(paths[i]);
|
||||
if (!file.exists()) {
|
||||
tag = false;
|
||||
LOGGER.severe(Common.addTag("[checkPath] the path is not exist, please check"));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return tag;
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether the ip set by user is valid.
|
||||
*
|
||||
* @param ip the ip set by user.
|
||||
* @return boolean value, true indicates the ip is valid, false indicates the ip is not valid.
|
||||
*/
|
||||
public static boolean checkIP(String ip) {
|
||||
String regex = "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])";
|
||||
if (ip == null) {
|
||||
LOGGER.severe(Common.addTag("[checkIP] the input argument <ip> is null, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
String regex = "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])[.]" +
|
||||
"(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.]" +
|
||||
"(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.]" +
|
||||
"(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])";
|
||||
Pattern pattern = Pattern.compile(regex);
|
||||
Matcher matcher = pattern.matcher(ip);
|
||||
return matcher.matches();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether the port set by user is valid.
|
||||
*
|
||||
* @param port the port set by user.
|
||||
* @return boolean value, true indicates the port is valid, false indicates the port is not valid.
|
||||
*/
|
||||
public static boolean checkPort(int port) {
|
||||
return port > 0 && port <= 65535;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain secure random.
|
||||
*
|
||||
* @return the secure random.
|
||||
*/
|
||||
public static SecureRandom getSecureRandom() {
|
||||
if (secureRandom == null) {
|
||||
LOGGER.severe(Common.addTag("[setSecureRandom] the parameter secureRandom is null, please set it before " +
|
||||
"use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return secureRandom;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the secure random to parameter secureRandom of the class Common.
|
||||
*
|
||||
* @param secureRandom the secure random.
|
||||
*/
|
||||
public static void setSecureRandom(SecureRandom secureRandom) {
|
||||
if (secureRandom == null) {
|
||||
LOGGER.severe(Common.addTag("[setSecureRandom] the input parameter secureRandom is null, please check"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
Common.secureRandom = secureRandom;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain fast secure random.
|
||||
*
|
||||
* @return the fast secure random.
|
||||
*/
|
||||
public static SecureRandom getFastSecureRandom() {
|
||||
try {
|
||||
LOGGER.info(Common.addTag("[getFastSecureRandom] start create fastSecureRandom"));
|
||||
long start = System.currentTimeMillis();
|
||||
SecureRandom blockingRandom = SecureRandom.getInstanceStrong();
|
||||
boolean ifPredictionResistant = true;
|
||||
BlockCipher cipher = new AESEngine();
|
||||
int cipherLen = 256;
|
||||
int entropyBitsRequired = 384;
|
||||
byte[] nonce = null;
|
||||
boolean ifForceReseed = false;
|
||||
SecureRandom fastRandom = new SP800SecureRandomBuilder(blockingRandom, ifPredictionResistant)
|
||||
.setEntropyBitsRequired(entropyBitsRequired)
|
||||
.buildCTR(cipher, cipherLen, nonce, ifForceReseed);
|
||||
fastRandom.nextInt();
|
||||
LOGGER.info(Common.addTag("[getFastSecureRandom] finish create fastSecureRandom"));
|
||||
LOGGER.info(Common.addTag("[getFastSecureRandom] cost time: " + (System.currentTimeMillis() - start)));
|
||||
return fastRandom;
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException: " + e.getMessage()));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,13 +13,17 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
/**
|
||||
* The early stop mod.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public enum EarlyStopMod {
|
||||
LOSS_DIFF,
|
||||
LOSS_ABS,
|
||||
WEIGHT_DIFF,
|
||||
NOT_EARLY_STOP
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,12 +13,16 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
/**
|
||||
* Security encryption level.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public enum EncryptLevel {
|
||||
PW_ENCRYPT,
|
||||
DP_ENCRYPT,
|
||||
NOT_ENCRYPT
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -1,21 +1,26 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
/**
|
||||
* The status code of federated learning.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public enum FLClientStatus {
|
||||
SUCCESS,
|
||||
FAILED,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +16,8 @@
|
|||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.TIME_OUT;
|
||||
|
||||
import okhttp3.Call;
|
||||
import okhttp3.Callback;
|
||||
import okhttp3.MediaType;
|
||||
|
@ -24,12 +26,6 @@ import okhttp3.Request;
|
|||
import okhttp3.RequestBody;
|
||||
import okhttp3.Response;
|
||||
|
||||
import javax.net.ssl.HostnameVerifier;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSession;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import java.io.IOException;
|
||||
import java.security.KeyManagementException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
|
@ -39,34 +35,40 @@ import java.util.concurrent.TimeUnit;
|
|||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.TIME_OUT;
|
||||
import javax.net.ssl.HostnameVerifier;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSession;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
|
||||
/**
|
||||
* Define the communication interface.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class FLCommunication implements IFLCommunication {
|
||||
private static int timeOut;
|
||||
private static boolean ssl = false;
|
||||
private static String env;
|
||||
private static SSLSocketFactory sslSocketFactory;
|
||||
private static X509TrustManager x509TrustManager;
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private static boolean ifCertificateVerify = false;
|
||||
private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8");
|
||||
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
|
||||
private OkHttpClient client;
|
||||
|
||||
private static volatile FLCommunication communication;
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private OkHttpClient client;
|
||||
|
||||
private FLCommunication() {
|
||||
if (flParameter.getTimeOut() != 0) {
|
||||
timeOut = flParameter.getTimeOut();
|
||||
} else {
|
||||
timeOut = TIME_OUT;
|
||||
}
|
||||
ssl = flParameter.isUseSSL();
|
||||
ifCertificateVerify = flParameter.isUseSSL();
|
||||
client = getOkHttpClient();
|
||||
}
|
||||
|
||||
private static OkHttpClient getOkHttpClient() {
|
||||
X509TrustManager trustManager = new X509TrustManager() {
|
||||
|
||||
@Override
|
||||
public X509Certificate[] getAcceptedIssuers() {
|
||||
return new X509Certificate[]{};
|
||||
|
@ -89,14 +91,15 @@ public class FLCommunication implements IFLCommunication {
|
|||
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
|
||||
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
|
||||
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
|
||||
if (ssl) {
|
||||
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(), SSLSocketFactoryTools.getInstance().getmTrustManager());
|
||||
if (ifCertificateVerify) {
|
||||
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(),
|
||||
SSLSocketFactoryTools.getInstance().getmTrustManager());
|
||||
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
|
||||
} else {
|
||||
final SSLContext sslContext = SSLContext.getInstance("TLS");
|
||||
sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
|
||||
final javax.net.ssl.SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
|
||||
builder.sslSocketFactory(sslSocketFactory, trustManager);
|
||||
sslContext.init(null, trustAllCerts, Common.getSecureRandom());
|
||||
final SSLSocketFactory sslFactory = sslContext.getSocketFactory();
|
||||
builder.sslSocketFactory(sslFactory, trustManager);
|
||||
builder.hostnameVerifier(new HostnameVerifier() {
|
||||
@Override
|
||||
public boolean verify(String arg0, SSLSession arg1) {
|
||||
|
@ -104,14 +107,18 @@ public class FLCommunication implements IFLCommunication {
|
|||
}
|
||||
});
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
} catch (NoSuchAlgorithmException | KeyManagementException e) {
|
||||
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + e.getMessage()));
|
||||
throw new RuntimeException(e);
|
||||
} catch (NoSuchAlgorithmException | KeyManagementException ex) {
|
||||
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + ex.getMessage()));
|
||||
throw new IllegalArgumentException(ex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the singleton object of the class FLCommunication.
|
||||
*
|
||||
* @return the singleton object of the class FLCommunication.
|
||||
*/
|
||||
public static FLCommunication getInstance() {
|
||||
FLCommunication localRef = communication;
|
||||
if (localRef == null) {
|
||||
|
@ -138,6 +145,9 @@ public class FLCommunication implements IFLCommunication {
|
|||
if (!response.isSuccessful()) {
|
||||
throw new IOException("Unexpected code " + response);
|
||||
}
|
||||
if (response.body() == null) {
|
||||
throw new IOException("the returned response is null");
|
||||
}
|
||||
return response.body().bytes();
|
||||
}
|
||||
|
||||
|
@ -159,11 +169,10 @@ public class FLCommunication implements IFLCommunication {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Call call, IOException e) {
|
||||
asyncCallBack.onFailure(e);
|
||||
public void onFailure(Call call, IOException ioException) {
|
||||
asyncCallBack.onFailure(ioException);
|
||||
call.cancel();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,19 +13,40 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Define job result callback function.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class FLJobResultCallback implements IFLJobResultCallback {
|
||||
private static final Logger LOGGER = Logger.getLogger(FLJobResultCallback.class.toString());
|
||||
|
||||
/**
|
||||
* Called at the end of an iteration for Fl job
|
||||
*
|
||||
* @param modelName the name of model
|
||||
* @param iterationSeq Iteration number
|
||||
* @param resultCode Status Code
|
||||
*/
|
||||
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) {
|
||||
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode));
|
||||
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " +
|
||||
iterationSeq + " resultCode: " + resultCode));
|
||||
}
|
||||
|
||||
/**
|
||||
* Called on completion for Fl job
|
||||
*
|
||||
* @param modelName the name of model
|
||||
* @param iterationCount total Iteration numbers
|
||||
* @param resultCode Status Code
|
||||
*/
|
||||
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
|
||||
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode));
|
||||
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " +
|
||||
iterationCount + " resultCode: " + resultCode));
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,11 +16,15 @@
|
|||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.mindspore.flclient.cipher.BaseUtil;
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
|
||||
import mindspore.schema.CipherPublicParams;
|
||||
import mindspore.schema.FLPlan;
|
||||
import mindspore.schema.ResponseCode;
|
||||
|
@ -36,18 +40,21 @@ import java.util.Map;
|
|||
import java.util.TreeMap;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
/**
|
||||
* Defining the general process of federated learning tasks.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class FLLiteClient {
|
||||
private static final Logger LOGGER = Logger.getLogger(FLLiteClient.class.toString());
|
||||
private FLCommunication flCommunication;
|
||||
private static int iteration = 0;
|
||||
private static Map<String, float[]> mapBeforeTrain;
|
||||
|
||||
private double dpNormClipFactor = 1.0d;
|
||||
private double dpNormClipAdapt = 0.05d;
|
||||
private FLCommunication flCommunication;
|
||||
private FLClientStatus status;
|
||||
private int retCode;
|
||||
|
||||
private static int iteration = 0;
|
||||
private int iterations = 1;
|
||||
private int epochs = 1;
|
||||
private int batchSize = 16;
|
||||
|
@ -55,22 +62,21 @@ public class FLLiteClient {
|
|||
private byte[] prime;
|
||||
private int featureSize;
|
||||
private int trainDataSize;
|
||||
private double dpEps = 100;
|
||||
private double dpDelta = 0.01;
|
||||
public double dpNormClipFactor = 1.0;
|
||||
public double dpNormClipAdapt = 0.05;
|
||||
|
||||
private double dpEps = 100d;
|
||||
private double dpDelta = 0.01d;
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private SecureProtocol secureProtocol = new SecureProtocol();
|
||||
private static Map<String, float[]> mapBeforeTrain;
|
||||
private String nextRequestTime;
|
||||
|
||||
/**
|
||||
* Defining a constructor of teh class FLLiteClient.
|
||||
*/
|
||||
public FLLiteClient() {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
}
|
||||
|
||||
public int setGlobalParameters(ResponseFLJob flJob) {
|
||||
private int setGlobalParameters(ResponseFLJob flJob) {
|
||||
FLPlan flPlan = flJob.flPlanConfig();
|
||||
if (flPlan == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the FLPlan get from server is null"));
|
||||
|
@ -90,14 +96,22 @@ public class FLLiteClient {
|
|||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
trainLenet.setBatchSize(batchSize);
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid"));
|
||||
return -1;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize));
|
||||
CipherPublicParams cipherPublicParams = flPlan.cipher();
|
||||
if (cipherPublicParams == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the cipherPublicParams returned from server is null"));
|
||||
return -1;
|
||||
}
|
||||
String encryptLevel = cipherPublicParams.encryptType();
|
||||
if ("".equals(encryptLevel) || encryptLevel.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server is null, set the encryptLevel to NOT_ENCRYPT "));
|
||||
if (encryptLevel == null || encryptLevel.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server is null, set the " +
|
||||
"encryptLevel to NOT_ENCRYPT "));
|
||||
localFLParameter.setEncryptLevel(EncryptLevel.NOT_ENCRYPT.toString());
|
||||
} else {
|
||||
localFLParameter.setEncryptLevel(encryptLevel);
|
||||
|
@ -113,10 +127,10 @@ public class FLLiteClient {
|
|||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
|
||||
if (minSecretNum <= 0) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid: <=0"));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid:" +
|
||||
" <=0"));
|
||||
return -1;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
|
||||
break;
|
||||
case DP_ENCRYPT:
|
||||
dpEps = cipherPublicParams.dpEps();
|
||||
|
@ -124,53 +138,97 @@ public class FLLiteClient {
|
|||
dpNormClipFactor = cipherPublicParams.dpNormClip();
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " + dpNormClipFactor));
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " +
|
||||
dpNormClipFactor));
|
||||
break;
|
||||
default:
|
||||
LOGGER.info(Common.addTag("[startFLJob] NotEncrypt, do not set parameter for Encrypt"));
|
||||
LOGGER.info(Common.addTag("[startFLJob] NOT_ENCRYPT, do not set parameter for Encrypt"));
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain retCode returned by server.
|
||||
*
|
||||
* @return the retCode returned by server.
|
||||
*/
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain current iteration returned by server.
|
||||
*
|
||||
* @return the current iteration returned by server.
|
||||
*/
|
||||
public int getIteration() {
|
||||
return iteration;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain total iterations for the task returned by server.
|
||||
*
|
||||
* @return the total iterations for the task returned by server.
|
||||
*/
|
||||
public int getIterations() {
|
||||
return iterations;
|
||||
}
|
||||
|
||||
public int getEpochs() {
|
||||
return epochs;
|
||||
}
|
||||
|
||||
public int getBatchSize() {
|
||||
return batchSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain the returned timestamp for next request from server.
|
||||
*
|
||||
* @return the timestamp for next request.
|
||||
*/
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public void setNextRequestTime(String nextRequestTime) {
|
||||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the size of train date set.
|
||||
*
|
||||
* @param trainDataSize the size of train date set.
|
||||
*/
|
||||
public void setTrainDataSize(int trainDataSize) {
|
||||
this.trainDataSize = trainDataSize;
|
||||
}
|
||||
|
||||
public FLClientStatus checkStatus() {
|
||||
return this.status;
|
||||
/**
|
||||
* Obtain the dpNormClipFactor.
|
||||
*
|
||||
* @return the dpNormClipFactor.
|
||||
*/
|
||||
public double getDpNormClipFactor() {
|
||||
return dpNormClipFactor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain the dpNormClipAdapt.
|
||||
*
|
||||
* @return the dpNormClipAdapt.
|
||||
*/
|
||||
public double getDpNormClipAdapt() {
|
||||
return dpNormClipAdapt;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the dpNormClipAdapt.
|
||||
*
|
||||
* @param dpNormClipAdapt the dpNormClipAdapt.
|
||||
*/
|
||||
public void setDpNormClipAdapt(double dpNormClipAdapt) {
|
||||
this.dpNormClipAdapt = dpNormClipAdapt;
|
||||
}
|
||||
|
||||
/**
|
||||
* Send serialized request message of startFLJob to server.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus startFLJob() {
|
||||
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify server===================================="));
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify " +
|
||||
"server===================================="));
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
StartFLJob startFLJob = StartFLJob.getInstance();
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
|
@ -179,8 +237,9 @@ public class FLLiteClient {
|
|||
long start = Common.startTime("single startFLJob");
|
||||
LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length));
|
||||
byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg);
|
||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(message)) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] the server is not ready now, need wait some time and request " +
|
||||
"again"));
|
||||
status = FLClientStatus.RESTART;
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
|
@ -193,14 +252,15 @@ public class FLLiteClient {
|
|||
status = judgeStartFLJob(startFLJob, responseDataBuf);
|
||||
retCode = responseDataBuf.retcode();
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage()));
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " +
|
||||
e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
public FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
|
||||
private FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
|
||||
iteration = responseDataBuf.iteration();
|
||||
FLClientStatus response = startFLJob.doResponse(responseDataBuf);
|
||||
status = response;
|
||||
|
@ -218,6 +278,10 @@ public class FLLiteClient {
|
|||
break;
|
||||
case RESTART:
|
||||
FLPlan flPlan = responseDataBuf.flPlanConfig();
|
||||
if (flPlan == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the flPlan returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
iterations = flPlan.iterations();
|
||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
|
||||
nextRequestTime = responseDataBuf.nextReqTime();
|
||||
|
@ -226,14 +290,21 @@ public class FLLiteClient {
|
|||
LOGGER.severe(Common.addTag("[startFLJob] startFLJob failed"));
|
||||
break;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range <SUCCESS, WAIT, FAILED, Restart>"));
|
||||
LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range " +
|
||||
"<SUCCESS, WAIT, FAILED, Restart>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* Define the training process.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus localTrain() {
|
||||
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "===================================="));
|
||||
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration +
|
||||
"===================================="));
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
|
@ -254,12 +325,22 @@ public class FLLiteClient {
|
|||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[train] the flName is not valid"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* Send serialized request message of updateModel to server.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus updateModel() {
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
UpdateModel updateModelBuf = UpdateModel.getInstance();
|
||||
byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize);
|
||||
if (updateModelBuf.getStatus() == FLClientStatus.FAILED) {
|
||||
|
@ -270,8 +351,9 @@ public class FLLiteClient {
|
|||
long start = Common.startTime("single updateModel");
|
||||
LOGGER.info(Common.addTag("[updateModel] the request message length: " + updateModelBuffer.length));
|
||||
byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer);
|
||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[updateModel] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(message)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] the server is not ready now, need wait some time and request" +
|
||||
" again"));
|
||||
status = FLClientStatus.RESTART;
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
|
@ -288,23 +370,31 @@ public class FLLiteClient {
|
|||
}
|
||||
LOGGER.info(Common.addTag("[updateModel] get response from server ok!"));
|
||||
} catch (IOException e) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage()));
|
||||
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " +
|
||||
e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* Send serialized request message of getModel to server.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus getModel() {
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
GetModel getModelBuf = GetModel.getInstance();
|
||||
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration);
|
||||
try {
|
||||
long start = Common.startTime("single getModel");
|
||||
LOGGER.info(Common.addTag("[getModel] the request message length: " + buffer.length));
|
||||
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(message)) {
|
||||
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
|
||||
"again"));
|
||||
status = FLClientStatus.WAIT;
|
||||
return status;
|
||||
}
|
||||
|
@ -327,6 +417,12 @@ public class FLLiteClient {
|
|||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain the weight of the model before training.
|
||||
*
|
||||
* @param map a map to store the weight of the model.
|
||||
* @return the weight.
|
||||
*/
|
||||
public static synchronized Map<String, float[]> getOldMapCopy(Map<String, float[]> map) {
|
||||
if (mapBeforeTrain == null) {
|
||||
Map<String, float[]> copyMap = new TreeMap<>();
|
||||
|
@ -334,7 +430,8 @@ public class FLLiteClient {
|
|||
float[] data = map.get(key);
|
||||
int dataLen = data.length;
|
||||
float[] weights = new float[dataLen];
|
||||
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
|
||||
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) &&
|
||||
(key.indexOf("learning") < 0)) {
|
||||
for (int j = 0; j < dataLen; j++) {
|
||||
float weight = data[j];
|
||||
weights[j] = weight;
|
||||
|
@ -348,7 +445,8 @@ public class FLLiteClient {
|
|||
float[] data = map.get(key);
|
||||
float[] copyData = mapBeforeTrain.get(key);
|
||||
int dataLen = data.length;
|
||||
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
|
||||
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) &&
|
||||
(key.indexOf("learning") < 0)) {
|
||||
for (int j = 0; j < dataLen; j++) {
|
||||
copyData[j] = data[j];
|
||||
}
|
||||
|
@ -358,18 +456,25 @@ public class FLLiteClient {
|
|||
return mapBeforeTrain;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain pairwise mask and individual mask.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus getFeatureMask() {
|
||||
FLClientStatus curStatus;
|
||||
switch (localFLParameter.getEncryptLevel()) {
|
||||
case PW_ENCRYPT:
|
||||
LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" + localFLParameter.getEncryptLevel().toString() + ">"));
|
||||
LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" +
|
||||
localFLParameter.getEncryptLevel().toString() + ">"));
|
||||
secureProtocol.setPWParameter(iteration, minSecretNum, prime, featureSize);
|
||||
curStatus = secureProtocol.pwCreateMask();
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
nextRequestTime = secureProtocol.getNextRequestTime();
|
||||
}
|
||||
retCode = secureProtocol.getRetCode();
|
||||
LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" + localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
|
||||
LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" +
|
||||
localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
|
||||
return curStatus;
|
||||
case DP_ENCRYPT:
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
|
@ -388,7 +493,7 @@ public class FLLiteClient {
|
|||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[Encrypt] set parameters for DPEncrypt!"));
|
||||
LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case NOT_ENCRYPT:
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
|
@ -401,6 +506,11 @@ public class FLLiteClient {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconstruct the secrets used for unmasking model weights.
|
||||
*
|
||||
* @return current status code in client.
|
||||
*/
|
||||
public FLClientStatus unMasking() {
|
||||
FLClientStatus curStatus;
|
||||
switch (localFLParameter.getEncryptLevel()) {
|
||||
|
@ -413,7 +523,7 @@ public class FLLiteClient {
|
|||
}
|
||||
return curStatus;
|
||||
case DP_ENCRYPT:
|
||||
LOGGER.info(Common.addTag("[Encrypt] DPEncrypt do not need unmasking"));
|
||||
LOGGER.info(Common.addTag("[Encrypt] DP_ENCRYPT do not need unmasking"));
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
return FLClientStatus.SUCCESS;
|
||||
case NOT_ENCRYPT:
|
||||
|
@ -427,18 +537,26 @@ public class FLLiteClient {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluate model after getting model from server.
|
||||
*
|
||||
* @return the status code in client.
|
||||
*/
|
||||
public FLClientStatus evaluateModel() {
|
||||
status = FLClientStatus.SUCCESS;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
LOGGER.info(Common.addTag("===================================evaluate model after getting model from server==================================="));
|
||||
LOGGER.info(Common.addTag("===================================evaluate model after getting model from " +
|
||||
"server==================================="));
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
float acc = 0;
|
||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
|
||||
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
|
||||
flParameter.getIdsFile(), true);
|
||||
if (dataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the return dataSize<=0"));
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the " +
|
||||
"return dataSize<=0"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
|
@ -447,47 +565,66 @@ public class FLLiteClient {
|
|||
} else {
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile());
|
||||
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
|
||||
flParameter.getIdsFile());
|
||||
if (dataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0"));
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the " +
|
||||
"return dataSize<=0"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
acc = alTrainBert.evalModel();
|
||||
}
|
||||
if (acc == Float.NaN) {
|
||||
if (Float.isNaN(acc)) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <evalModel>: the return acc is NAN"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
|
||||
flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() +
|
||||
" idsFile: " + flParameter.getIdsFile()));
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0], flParameter.getTestDataset().split(",")[1]);
|
||||
if (flParameter.getTestDataset().split(",").length < 2) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] the set testDataPath for lenet is not valid, should be the " +
|
||||
"format of <data.bin,label.bin> "));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0],
|
||||
flParameter.getTestDataset().split(",")[1]);
|
||||
if (dataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return dataSize<=0"));
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return " +
|
||||
"dataSize<=0"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
float acc = trainLenet.evalModel();
|
||||
if (acc == Float.NaN) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc is NAN"));
|
||||
if (Float.isNaN(acc)) {
|
||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc" +
|
||||
" is NAN"));
|
||||
status = FLClientStatus.FAILED;
|
||||
retCode = ResponseCode.RequestError;
|
||||
return status;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset().split(",")[0] + " labelPath: " + flParameter.getTestDataset().split(",")[1]));
|
||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
|
||||
flParameter.getTestDataset().split(",")[0] + " labelPath: " +
|
||||
flParameter.getTestDataset().split(",")[1]));
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param dataPath, train or test dataset and label set
|
||||
* Set date path.
|
||||
*
|
||||
* @param dataPath, train or test dataset and label set.
|
||||
* @return date size.
|
||||
*/
|
||||
public int setInput(String dataPath) {
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
|
@ -496,15 +633,18 @@ public class FLLiteClient {
|
|||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
dataSize = alTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
|
||||
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
|
||||
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " " +
|
||||
"vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
if (dataPath.split(",").length < 2) {
|
||||
LOGGER.info(Common.addTag("[set input] the set dataPath for lenet is not valid, should be the format of <data.bin,label.bin>"));
|
||||
LOGGER.severe(Common.addTag("[set input] the set dataPath for lenet is not valid, should be the " +
|
||||
"format of <data.bin,label.bin> "));
|
||||
return -1;
|
||||
}
|
||||
dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]);
|
||||
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " + +dataSize + " labelPath: " + dataPath.split(",")[1]));
|
||||
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " +
|
||||
dataSize + " labelPath: " + dataPath.split(",")[1]));
|
||||
}
|
||||
if (dataSize <= 0) {
|
||||
retCode = ResponseCode.RequestError;
|
||||
|
@ -513,36 +653,48 @@ public class FLLiteClient {
|
|||
return dataSize;
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialization session.
|
||||
*
|
||||
* @return the status code in client.
|
||||
*/
|
||||
public FLClientStatus initSession() {
|
||||
int tag = 0;
|
||||
retCode = ResponseCode.SUCCEED;
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
|
||||
"Train Session============="));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
|
||||
"is -1"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
|
||||
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " +
|
||||
"Create inference Session============="));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
|
||||
"Train Session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
}
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is " +
|
||||
"-1"));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
/**
|
||||
* Free session.
|
||||
*/
|
||||
protected void freeSession() {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("===========free train session============="));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
|
@ -558,5 +710,4 @@ public class FLLiteClient {
|
|||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,22 +13,36 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.UUID;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Defines global parameters used during federated learning and these parameters are provided for users to set.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class FLParameter {
|
||||
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
|
||||
|
||||
/**
|
||||
* The timeout interval for communication on the device.
|
||||
*/
|
||||
public static final int TIME_OUT = 100;
|
||||
|
||||
/**
|
||||
* The waiting time of repeated requests.
|
||||
*/
|
||||
public static final int SLEEP_TIME = 1000;
|
||||
private static volatile FLParameter flParameter;
|
||||
|
||||
private String hostName;
|
||||
private String domainName;
|
||||
private String certPath;
|
||||
private boolean useHttps = false;
|
||||
|
||||
private String trainDataset;
|
||||
private String vocabFile = "null";
|
||||
private String idsFile = "null";
|
||||
|
@ -37,22 +51,21 @@ public class FLParameter {
|
|||
private String trainModelPath;
|
||||
private String inferModelPath;
|
||||
private String clientID;
|
||||
private String ip;
|
||||
private int port;
|
||||
private boolean useSSL = false;
|
||||
private int timeOut;
|
||||
private int sleepTime;
|
||||
private boolean useElb = false;
|
||||
private boolean ifUseElb = false;
|
||||
private int serverNum = 1;
|
||||
|
||||
private boolean timer = true;
|
||||
private int timeWindow = 6000;
|
||||
private int reRequestNum = timeWindow / SLEEP_TIME + 1;
|
||||
|
||||
private static volatile FLParameter flParameter;
|
||||
|
||||
private FLParameter() {}
|
||||
private FLParameter() {
|
||||
clientID = UUID.randomUUID().toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the singleton object of the class FLParameter.
|
||||
*
|
||||
* @return the singleton object of the class FLParameter.
|
||||
*/
|
||||
public static FLParameter getInstance() {
|
||||
FLParameter localRef = flParameter;
|
||||
if (localRef == null) {
|
||||
|
@ -66,95 +79,100 @@ public class FLParameter {
|
|||
return localRef;
|
||||
}
|
||||
|
||||
public String getHostName() {
|
||||
if ("".equals(hostName) || hostName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <hostName> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
public String getDomainName() {
|
||||
if (domainName == null || domainName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is null or empty, please set it " +
|
||||
"before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return hostName;
|
||||
return domainName;
|
||||
}
|
||||
|
||||
public void setHostName(String hostName) {
|
||||
this.hostName = hostName;
|
||||
public void setDomainName(String domainName) {
|
||||
if (domainName == null || domainName.isEmpty() || (!("https".equals(domainName.split(":")[0]) || "http".equals(domainName.split(":")[0])))) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is not valid, it should be like " +
|
||||
"as https://...... or http://......, please check it before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.domainName = domainName;
|
||||
}
|
||||
|
||||
public String getCertPath() {
|
||||
if ("".equals(certPath) || certPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
if (certPath == null || certPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null or empty, please set it " +
|
||||
"before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return certPath;
|
||||
}
|
||||
|
||||
public void setCertPath(String certPath) {
|
||||
certPath = Common.getRealPath(certPath);
|
||||
if (Common.checkPath(certPath)) {
|
||||
this.certPath = certPath;
|
||||
String realCertPath = Common.getRealPath(certPath);
|
||||
if (Common.checkPath(realCertPath)) {
|
||||
this.certPath = realCertPath;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it " +
|
||||
"before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isUseHttps() {
|
||||
return useHttps;
|
||||
}
|
||||
|
||||
public void setUseHttps(boolean useHttps) {
|
||||
this.useHttps = useHttps;
|
||||
}
|
||||
|
||||
public String getTrainDataset() {
|
||||
if ("".equals(trainDataset) || trainDataset.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
if (trainDataset == null || trainDataset.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null or empty, please set " +
|
||||
"it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return trainDataset;
|
||||
}
|
||||
|
||||
public void setTrainDataset(String trainDataset) {
|
||||
trainDataset = Common.getRealPath(trainDataset);
|
||||
if (Common.checkPath(trainDataset)) {
|
||||
this.trainDataset = trainDataset;
|
||||
String realTrainDataset = Common.getRealPath(trainDataset);
|
||||
if (Common.checkPath(realTrainDataset)) {
|
||||
this.trainDataset = realTrainDataset;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it " +
|
||||
"before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getVocabFile() {
|
||||
if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before " +
|
||||
"use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return vocabFile;
|
||||
}
|
||||
|
||||
public void setVocabFile(String vocabFile) {
|
||||
vocabFile = Common.getRealPath(vocabFile);
|
||||
if (Common.checkPath(vocabFile)) {
|
||||
this.vocabFile = vocabFile;
|
||||
String realVocabFile = Common.getRealPath(vocabFile);
|
||||
if (Common.checkPath(realVocabFile)) {
|
||||
this.vocabFile = realVocabFile;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it " +
|
||||
"before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getIdsFile() {
|
||||
if ("null".equals(idsFile) && ALBERT.equals(flName)) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return idsFile;
|
||||
}
|
||||
|
||||
public void setIdsFile(String idsFile) {
|
||||
idsFile = Common.getRealPath(idsFile);
|
||||
if (Common.checkPath(idsFile)) {
|
||||
this.idsFile = idsFile;
|
||||
String realIdsFile = Common.getRealPath(idsFile);
|
||||
if (Common.checkPath(realIdsFile)) {
|
||||
this.idsFile = realIdsFile;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it " +
|
||||
"before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,19 +181,21 @@ public class FLParameter {
|
|||
}
|
||||
|
||||
public void setTestDataset(String testDataset) {
|
||||
testDataset = Common.getRealPath(testDataset);
|
||||
if (Common.checkPath(testDataset)) {
|
||||
this.testDataset = testDataset;
|
||||
String realTestDataset = Common.getRealPath(testDataset);
|
||||
if (Common.checkPath(realTestDataset)) {
|
||||
this.testDataset = realTestDataset;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it " +
|
||||
"before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getFlName() {
|
||||
if ("".equals(flName) || flName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
if (flName == null || flName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null or empty, please set it " +
|
||||
"before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flName;
|
||||
}
|
||||
|
@ -184,61 +204,50 @@ public class FLParameter {
|
|||
if (Common.checkFLName(flName)) {
|
||||
this.flName = flName;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in flNameTrustList, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in FL_NAME_TRUST_LIST: " +
|
||||
Arrays.toString(Common.FL_NAME_TRUST_LIST.toArray(new String[0])) + ", please check it before " +
|
||||
"set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getTrainModelPath() {
|
||||
if ("".equals(trainModelPath) || trainModelPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
if (trainModelPath == null || trainModelPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null or empty, please set" +
|
||||
" it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return trainModelPath;
|
||||
}
|
||||
|
||||
public void setTrainModelPath(String trainModelPath) {
|
||||
trainModelPath = Common.getRealPath(trainModelPath);
|
||||
if (Common.checkPath(trainModelPath)) {
|
||||
this.trainModelPath = trainModelPath;
|
||||
String realTrainModelPath = Common.getRealPath(trainModelPath);
|
||||
if (Common.checkPath(realTrainModelPath)) {
|
||||
this.trainModelPath = realTrainModelPath;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check " +
|
||||
"it before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getInferModelPath() {
|
||||
if ("".equals(inferModelPath) || inferModelPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
if (inferModelPath == null || inferModelPath.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null or empty, please set" +
|
||||
" it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return inferModelPath;
|
||||
}
|
||||
|
||||
public void setInferModelPath(String inferModelPath) {
|
||||
inferModelPath = Common.getRealPath(inferModelPath);
|
||||
if (Common.checkPath(inferModelPath)) {
|
||||
this.inferModelPath = inferModelPath;
|
||||
String realInferModelPath = Common.getRealPath(inferModelPath);
|
||||
if (Common.checkPath(realInferModelPath)) {
|
||||
this.inferModelPath = realInferModelPath;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public String getIp() {
|
||||
if ("".equals(ip) || ip.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return ip;
|
||||
}
|
||||
|
||||
public void setIp(String ip) {
|
||||
if (Common.checkIP(ip)) {
|
||||
this.ip = ip;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is not valid, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check " +
|
||||
"it before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -250,23 +259,6 @@ public class FLParameter {
|
|||
this.useSSL = useSSL;
|
||||
}
|
||||
|
||||
public int getPort() {
|
||||
if (port == 0) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
return port;
|
||||
}
|
||||
|
||||
public void setPort(int port) {
|
||||
if (Common.checkPort(port)) {
|
||||
this.port = port;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is not valid, please check it before set"));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
}
|
||||
|
||||
public int getTimeOut() {
|
||||
return timeOut;
|
||||
}
|
||||
|
@ -284,17 +276,18 @@ public class FLParameter {
|
|||
}
|
||||
|
||||
public boolean isUseElb() {
|
||||
return useElb;
|
||||
return ifUseElb;
|
||||
}
|
||||
|
||||
public void setUseElb(boolean useElb) {
|
||||
this.useElb = useElb;
|
||||
public void setUseElb(boolean ifUseElb) {
|
||||
this.ifUseElb = ifUseElb;
|
||||
}
|
||||
|
||||
public int getServerNum() {
|
||||
if (serverNum <= 0) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> is <= 0, it should be > 0, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> <= 0, it should be > 0, please " +
|
||||
"set it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return serverNum;
|
||||
}
|
||||
|
@ -303,40 +296,11 @@ public class FLParameter {
|
|||
this.serverNum = serverNum;
|
||||
}
|
||||
|
||||
public boolean isTimer() {
|
||||
return timer;
|
||||
}
|
||||
|
||||
public void setTimer(boolean timer) {
|
||||
this.timer = timer;
|
||||
}
|
||||
|
||||
public int getTimeWindow() {
|
||||
return timeWindow;
|
||||
}
|
||||
|
||||
public void setTimeWindow(int timeWindow) {
|
||||
this.timeWindow = timeWindow;
|
||||
}
|
||||
|
||||
public int getReRequestNum() {
|
||||
return reRequestNum;
|
||||
}
|
||||
|
||||
public void setReRequestNum(int reRequestNum) {
|
||||
this.reRequestNum = reRequestNum;
|
||||
}
|
||||
|
||||
public String getClientID() {
|
||||
if ("".equals(clientID) || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
if (clientID == null || clientID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return clientID;
|
||||
}
|
||||
|
||||
public void setClientID(String clientID) {
|
||||
this.clientID = clientID;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,13 +13,19 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestGetModel;
|
||||
import mindspore.schema.ResponseCode;
|
||||
|
@ -29,57 +35,30 @@ import java.util.ArrayList;
|
|||
import java.util.Date;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
/**
|
||||
* Define the serialization method, handle the response message returned from server for getModel request.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class GetModel {
|
||||
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
|
||||
private static volatile GetModel getModel;
|
||||
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
class RequestGetModelBuilder {
|
||||
private FlatBufferBuilder builder;
|
||||
private int nameOffset = 0;
|
||||
private int iteration = 0;
|
||||
private int timeStampOffset = 0;
|
||||
|
||||
public RequestGetModelBuilder() {
|
||||
builder = new FlatBufferBuilder();
|
||||
}
|
||||
|
||||
public RequestGetModelBuilder flName(String name) {
|
||||
this.nameOffset = this.builder.createString(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestGetModelBuilder time() {
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
this.timeStampOffset = builder.createString(String.valueOf(time));
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestGetModelBuilder iteration(int iteration) {
|
||||
this.iteration = iteration;
|
||||
return this;
|
||||
}
|
||||
|
||||
public byte[] build() {
|
||||
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
|
||||
private static volatile GetModel getModel;
|
||||
|
||||
private GetModel() {
|
||||
}
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
|
||||
private GetModel() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the singleton object of the class GetModel.
|
||||
*
|
||||
* @return the singleton object of the class GetModel.
|
||||
*/
|
||||
public static GetModel getInstance() {
|
||||
GetModel localRef = getModel;
|
||||
if (localRef == null) {
|
||||
|
@ -93,7 +72,18 @@ public class GetModel {
|
|||
return localRef;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a flatBuffer builder of RequestGetModel.
|
||||
*
|
||||
* @param name the model name.
|
||||
* @param iteration current iteration of federated learning task.
|
||||
* @return the flatBuffer builder of RequestGetModel in byte[] format.
|
||||
*/
|
||||
public byte[] getRequestGetModel(String name, int iteration) {
|
||||
if (name == null || name.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[GetModel] the input parameter of <name> is null or empty, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
RequestGetModelBuilder builder = new RequestGetModelBuilder();
|
||||
return builder.iteration(iteration).flName(name).time().build();
|
||||
}
|
||||
|
@ -107,6 +97,10 @@ public class GetModel {
|
|||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
|
@ -116,36 +110,46 @@ public class GetModel {
|
|||
} else {
|
||||
continue;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength:" +
|
||||
" " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------"));
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " +
|
||||
"model-----------------"));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
|
||||
inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||
albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
|
||||
feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||
featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
|
@ -159,9 +163,14 @@ public class GetModel {
|
|||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
|
||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
|
||||
feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
||||
|
@ -174,7 +183,12 @@ public class GetModel {
|
|||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Handle the response message returned from server.
|
||||
*
|
||||
* @param responseDataBuf the response message returned from server.
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus doResponse(ResponseGetModel responseDataBuf) {
|
||||
LOGGER.info(Common.addTag("[getModel] ==========get model content is:================"));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode()));
|
||||
|
@ -186,13 +200,15 @@ public class GetModel {
|
|||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[getModel] getModel response success"));
|
||||
|
||||
if (ALBERT.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||
status = parseResponseAlbert(responseDataBuf);
|
||||
} else if (LENET.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
||||
status = parseResponseLenet(responseDataBuf);
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return status;
|
||||
case (ResponseCode.SucNotReady):
|
||||
|
@ -211,4 +227,42 @@ public class GetModel {
|
|||
}
|
||||
}
|
||||
|
||||
class RequestGetModelBuilder {
|
||||
private FlatBufferBuilder builder;
|
||||
private int nameOffset = 0;
|
||||
private int iteration = 0;
|
||||
private int timeStampOffset = 0;
|
||||
|
||||
public RequestGetModelBuilder() {
|
||||
builder = new FlatBufferBuilder();
|
||||
}
|
||||
|
||||
private RequestGetModelBuilder flName(String name) {
|
||||
if (name == null || name.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[GetModel] the input parameter of <name> is null or empty, please " +
|
||||
"check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.nameOffset = this.builder.createString(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
private RequestGetModelBuilder time() {
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
this.timeStampOffset = builder.createString(String.valueOf(time));
|
||||
return this;
|
||||
}
|
||||
|
||||
private RequestGetModelBuilder iteration(int iteration) {
|
||||
this.iteration = iteration;
|
||||
return this;
|
||||
}
|
||||
|
||||
private byte[] build() {
|
||||
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,25 +1,42 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Define asynchronous communication call back interface.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public interface IAsyncCallBack {
|
||||
public FLClientStatus onFailure(IOException exception);
|
||||
/**
|
||||
* Automatically invoked when the request fails.
|
||||
*
|
||||
* @param exception the catch exception.
|
||||
* @return the status code in client.
|
||||
*/
|
||||
FLClientStatus onFailure(IOException exception);
|
||||
|
||||
public FLClientStatus onResponse(byte[] msg);
|
||||
/**
|
||||
* Automatically invoked when a response message is processed.
|
||||
*
|
||||
* @param msg the response message.
|
||||
* @return the status code in client.
|
||||
*/
|
||||
FLClientStatus onResponse(byte[] msg);
|
||||
}
|
||||
|
|
|
@ -1,33 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
/**
|
||||
* @author smurf
|
||||
* Define basic communication interface.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public interface IFLCommunication {
|
||||
/**
|
||||
* Sets the timeout interval for communication on the device.
|
||||
*
|
||||
* @param timeout the timeout interval for communication on the device.
|
||||
* @throws TimeoutException catch TimeoutException.
|
||||
*/
|
||||
void setTimeOut(int timeout) throws TimeoutException;
|
||||
|
||||
public void setTimeOut(int timeout) throws TimeoutException;
|
||||
|
||||
public byte[] syncRequest(String url, byte[] msg) throws Exception;
|
||||
|
||||
public void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception;
|
||||
}
|
||||
/**
|
||||
* Synchronization request function.
|
||||
*
|
||||
* @param url the URL for device-sever interaction set by user.
|
||||
* @param msg the message need to be sent to server.
|
||||
* @return the response message.
|
||||
* @throws Exception catch Exception.
|
||||
*/
|
||||
byte[] syncRequest(String url, byte[] msg) throws Exception;
|
||||
|
||||
/**
|
||||
* Asynchronous request function.
|
||||
*
|
||||
* @param url the URL for device-sever interaction set by user.
|
||||
* @param msg the message need to be sent to server.
|
||||
* @param callBack the call back object.
|
||||
* @throws Exception catch Exception.
|
||||
*/
|
||||
void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception;
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,22 +13,30 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
/**
|
||||
* Define job result callback function interface.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public interface IFLJobResultCallback {
|
||||
/**
|
||||
* Called at the end of an iteration for Fl job
|
||||
* @param modelName the name of model
|
||||
* Called at the end of an iteration for Fl job
|
||||
*
|
||||
* @param modelName the name of model
|
||||
* @param iterationSeq Iteration number
|
||||
* @param resultCode Status Code
|
||||
* @param resultCode Status Code
|
||||
*/
|
||||
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode);
|
||||
void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode);
|
||||
|
||||
/**
|
||||
* Called on completion for Fl job
|
||||
* @param modelName the name of model
|
||||
*
|
||||
* @param modelName the name of model
|
||||
* @param iterationCount total Iteration numbers
|
||||
* @param resultCode Status Code
|
||||
* @param resultCode Status Code
|
||||
*/
|
||||
public void onFlJobFinished(String modelName, int iterationCount, int resultCode);
|
||||
void onFlJobFinished(String modelName, int iterationCount, int resultCode);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,28 +13,60 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import org.bouncycastle.math.ec.rfc7748.X25519;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Defines global parameters used internally during federated learning.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class LocalFLParameter {
|
||||
private static final Logger LOGGER = Logger.getLogger(LocalFLParameter.class.toString());
|
||||
|
||||
/**
|
||||
* Seed length used to generate random perturbations
|
||||
*/
|
||||
public static final int SEED_SIZE = 32;
|
||||
public static final int IVEC_LEN = 16;
|
||||
|
||||
/**
|
||||
* The length of IV value
|
||||
*/
|
||||
public static final int I_VEC_LEN = 16;
|
||||
|
||||
/**
|
||||
* The length of salt value
|
||||
*/
|
||||
public static final int SALT_SIZE = 32;
|
||||
|
||||
/**
|
||||
* the key length
|
||||
*/
|
||||
public static final int KEY_LEN = X25519.SCALAR_SIZE;
|
||||
|
||||
/**
|
||||
* The model name supported by federated learning tasks: "lenet".
|
||||
*/
|
||||
public static final String LENET = "lenet";
|
||||
|
||||
/**
|
||||
* The model name supported by federated learning tasks: "albert".
|
||||
*/
|
||||
public static final String ALBERT = "albert";
|
||||
private static volatile LocalFLParameter localFLParameter;
|
||||
|
||||
private List<String> classifierWeightName = new ArrayList<>();
|
||||
private List<String> albertWeightName = new ArrayList<>();
|
||||
|
||||
private String flID;
|
||||
private String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString();
|
||||
private String earlyStopMod = EarlyStopMod.NOT_EARLY_STOP.toString();
|
||||
private String serverMod = ServerMod.HYBRID_TRAINING.toString();
|
||||
private String safeMod = "The cluster is in safemode.";
|
||||
|
||||
private static volatile LocalFLParameter localFLParameter;
|
||||
|
||||
private LocalFLParameter() {
|
||||
// set classifierWeightName albertWeightName
|
||||
|
@ -42,6 +74,11 @@ public class LocalFLParameter {
|
|||
Common.setAlbertWeightName(albertWeightName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the singleton object of the class LocalFLParameter.
|
||||
*
|
||||
* @return the singleton object of the class LocalFLParameter.
|
||||
*/
|
||||
public static LocalFLParameter getInstance() {
|
||||
LocalFLParameter localRef = localFLParameter;
|
||||
if (localRef == null) {
|
||||
|
@ -57,8 +94,9 @@ public class LocalFLParameter {
|
|||
|
||||
public List<String> getClassifierWeightName() {
|
||||
if (classifierWeightName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please " +
|
||||
"set it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return classifierWeightName;
|
||||
}
|
||||
|
@ -69,8 +107,9 @@ public class LocalFLParameter {
|
|||
|
||||
public List<String> getAlbertWeightName() {
|
||||
if (albertWeightName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please " +
|
||||
"set it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return albertWeightName;
|
||||
}
|
||||
|
@ -80,14 +119,20 @@ public class LocalFLParameter {
|
|||
}
|
||||
|
||||
public String getFlID() {
|
||||
if ("".equals(flID) || flID == null) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before use"));
|
||||
throw new RuntimeException();
|
||||
if (flID == null || flID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before " +
|
||||
"use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flID;
|
||||
}
|
||||
|
||||
public void setFlID(String flID) {
|
||||
if (flID == null || flID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please check it before " +
|
||||
"set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
|
@ -96,6 +141,18 @@ public class LocalFLParameter {
|
|||
}
|
||||
|
||||
public void setEncryptLevel(String encryptLevel) {
|
||||
if (encryptLevel == null || encryptLevel.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is null, please check it " +
|
||||
"before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) &&
|
||||
(!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) &&
|
||||
(!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is " + encryptLevel + " ," +
|
||||
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.encryptLevel = encryptLevel;
|
||||
}
|
||||
|
||||
|
@ -104,6 +161,19 @@ public class LocalFLParameter {
|
|||
}
|
||||
|
||||
public void setEarlyStopMod(String earlyStopMod) {
|
||||
if (earlyStopMod == null || earlyStopMod.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is null, please check it " +
|
||||
"before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if ((!EarlyStopMod.NOT_EARLY_STOP.toString().equals(earlyStopMod)) &&
|
||||
(!EarlyStopMod.LOSS_ABS.toString().equals(earlyStopMod)) &&
|
||||
(!EarlyStopMod.LOSS_DIFF.toString().equals(earlyStopMod)) &&
|
||||
(!EarlyStopMod.WEIGHT_DIFF.toString().equals(earlyStopMod))) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is " + earlyStopMod + " ," +
|
||||
" it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.earlyStopMod = earlyStopMod;
|
||||
}
|
||||
|
||||
|
@ -112,14 +182,17 @@ public class LocalFLParameter {
|
|||
}
|
||||
|
||||
public void setServerMod(String serverMod) {
|
||||
if (serverMod == null || serverMod.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is null, please check it " +
|
||||
"before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if ((!ServerMod.HYBRID_TRAINING.toString().equals(serverMod)) &&
|
||||
(!ServerMod.FEDERATED_LEARNING.toString().equals(serverMod))) {
|
||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is " + serverMod + " , it " +
|
||||
"must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before set"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.serverMod = serverMod;
|
||||
}
|
||||
|
||||
public String getSafeMod() {
|
||||
return safeMod;
|
||||
}
|
||||
|
||||
public void setSafeMod(String safeMod) {
|
||||
this.safeMod = safeMod;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,33 +13,73 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.security.InvalidKeyException;
|
||||
import java.security.KeyManagementException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.NoSuchProviderException;
|
||||
import java.security.SignatureException;
|
||||
import java.security.cert.Certificate;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.CertificateFactory;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import javax.net.ssl.HostnameVerifier;
|
||||
import javax.net.ssl.SSLContext;
|
||||
import javax.net.ssl.SSLSession;
|
||||
import javax.net.ssl.SSLSocketFactory;
|
||||
import javax.net.ssl.TrustManager;
|
||||
import javax.net.ssl.X509TrustManager;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.security.InvalidKeyException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.NoSuchProviderException;
|
||||
import java.security.SignatureException;
|
||||
import java.security.cert.CertificateException;
|
||||
import java.security.cert.CertificateFactory;
|
||||
import java.security.cert.X509Certificate;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Define SSL socket factory tools for https communication.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class SSLSocketFactoryTools {
|
||||
private static final Logger LOGGER = Logger.getLogger(SSLSocketFactory.class.toString());
|
||||
private static volatile SSLSocketFactoryTools sslSocketFactoryTools;
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private X509Certificate x509Certificate;
|
||||
private SSLSocketFactory sslSocketFactory;
|
||||
private SSLContext sslContext;
|
||||
private MyTrustManager myTrustManager;
|
||||
private static volatile SSLSocketFactoryTools sslSocketFactoryTools;
|
||||
private final HostnameVerifier hostnameVerifier = new HostnameVerifier() {
|
||||
@Override
|
||||
public boolean verify(String hostname, SSLSession session) {
|
||||
if (hostname == null || hostname.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <hostname> is null or empty, " +
|
||||
"please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (session == null) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <session> is null, please " +
|
||||
"check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
String domainName = flParameter.getDomainName();
|
||||
if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the <domainName> is null or not valid, it should" +
|
||||
" be like as https://...... , please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (domainName.split("//")[1].split(":").length < 2) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the format of <domainName> is not valid, it " +
|
||||
"should be like as https://127.0.0.1:6666 when setting <useSSL> to true, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
String ip = domainName.split("//")[1].split(":")[0];
|
||||
return hostname.equals(ip);
|
||||
}
|
||||
};
|
||||
|
||||
private SSLSocketFactoryTools() {
|
||||
initSslSocketFactory();
|
||||
|
@ -52,14 +92,19 @@ public class SSLSocketFactoryTools {
|
|||
myTrustManager = new MyTrustManager(x509Certificate);
|
||||
sslContext.init(null, new TrustManager[]{
|
||||
myTrustManager
|
||||
}, new java.security.SecureRandom());
|
||||
}, Common.getSecureRandom());
|
||||
sslSocketFactory = sslContext.getSocketFactory();
|
||||
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " + e.getMessage()));
|
||||
} catch (NoSuchAlgorithmException | KeyManagementException ex) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " +
|
||||
ex.getMessage()));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the singleton object of the class SSLSocketFactoryTools.
|
||||
*
|
||||
* @return the singleton object of the class SSLSocketFactoryTools.
|
||||
*/
|
||||
public static SSLSocketFactoryTools getInstance() {
|
||||
SSLSocketFactoryTools localRef = sslSocketFactoryTools;
|
||||
if (localRef == null) {
|
||||
|
@ -73,29 +118,37 @@ public class SSLSocketFactoryTools {
|
|||
return localRef;
|
||||
}
|
||||
|
||||
public X509Certificate readCert(String assetName) {
|
||||
InputStream inputStream = null;
|
||||
try {
|
||||
inputStream = new FileInputStream(assetName);
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of read inputStream in readCert: " + e.getMessage()));
|
||||
private X509Certificate readCert(String assetName) {
|
||||
if (assetName == null || assetName.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <assetName> is null or empty, " +
|
||||
"please check!"));
|
||||
return null;
|
||||
}
|
||||
InputStream inputStream = null;
|
||||
X509Certificate cert = null;
|
||||
try {
|
||||
inputStream = new FileInputStream(assetName);
|
||||
CertificateFactory cf = CertificateFactory.getInstance("X.509");
|
||||
cert = (X509Certificate) cf.generateCertificate(inputStream);
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of creating CertificateFactory in readCert: " + e.getMessage()));
|
||||
Certificate certificate = cf.generateCertificate(inputStream);
|
||||
if (certificate instanceof X509Certificate) {
|
||||
cert = (X509Certificate) certificate;
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] cf.generateCertificate(inputStream) can not " +
|
||||
"convert to X509Certificate"));
|
||||
}
|
||||
} catch (FileNotFoundException | CertificateException ex) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch FileNotFoundException or CertificateException " +
|
||||
"when creating " +
|
||||
"CertificateFactory in readCert: " + ex.getMessage()));
|
||||
} finally {
|
||||
try {
|
||||
if (inputStream != null) {
|
||||
inputStream.close();
|
||||
}
|
||||
} catch (Throwable ex) {
|
||||
} catch (IOException ex) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch IOException: " + ex.getMessage()));
|
||||
}
|
||||
}
|
||||
|
||||
return cert;
|
||||
}
|
||||
|
||||
|
@ -111,7 +164,6 @@ public class SSLSocketFactoryTools {
|
|||
return myTrustManager;
|
||||
}
|
||||
|
||||
|
||||
private static final class MyTrustManager implements X509TrustManager {
|
||||
X509Certificate cert;
|
||||
|
||||
|
@ -126,27 +178,25 @@ public class SSLSocketFactoryTools {
|
|||
@Override
|
||||
public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
|
||||
for (X509Certificate cert : chain) {
|
||||
|
||||
// Make sure that it hasn't expired.
|
||||
cert.checkValidity();
|
||||
|
||||
// Verify the certificate's public key chain.
|
||||
try {
|
||||
cert.verify(((X509Certificate) this.cert).getPublicKey());
|
||||
cert.verify(this.cert.getPublicKey());
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage()));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
|
||||
"NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage()));
|
||||
} catch (InvalidKeyException e) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch InvalidKeyException in checkServerTrusted: " + e.getMessage()));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
|
||||
"InvalidKeyException in checkServerTrusted: " + e.getMessage()));
|
||||
} catch (NoSuchProviderException e) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchProviderException in checkServerTrusted: " + e.getMessage()));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
|
||||
"NoSuchProviderException in checkServerTrusted: " + e.getMessage()));
|
||||
} catch (SignatureException e) {
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch SignatureException in checkServerTrusted: " + e.getMessage()));
|
||||
throw new RuntimeException();
|
||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
|
||||
"SignatureException in checkServerTrusted: " + e.getMessage()));
|
||||
}
|
||||
LOGGER.info(Common.addTag("checkServerTrusted success!"));
|
||||
LOGGER.info(Common.addTag("**********************checkServerTrusted success!**********************"));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -155,14 +205,4 @@ public class SSLSocketFactoryTools {
|
|||
return new java.security.cert.X509Certificate[0];
|
||||
}
|
||||
}
|
||||
|
||||
private final HostnameVerifier hostnameVerifier = new HostnameVerifier() {
|
||||
@Override
|
||||
public boolean verify(String hostname, SSLSession session) {
|
||||
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] server hostname: " + flParameter.getHostName()));
|
||||
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] client request hostname: " + hostname));
|
||||
return hostname.equals(flParameter.getHostName());
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,127 +13,179 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
|
||||
import mindspore.schema.FeatureMap;
|
||||
|
||||
import java.security.SecureRandom;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
/**
|
||||
* Defines encryption and decryption methods.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class SecureProtocol {
|
||||
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
|
||||
private static double deltaError = 1e-6d;
|
||||
private static Map<String, float[]> modelMap;
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private int iteration;
|
||||
private CipherClient cipher;
|
||||
private CipherClient cipherClient;
|
||||
private FLClientStatus status;
|
||||
private float[] featureMask = new float[0];
|
||||
private double dpEps;
|
||||
private double dpDelta;
|
||||
private double dpNormClip;
|
||||
private static double deltaError = 1e-6;
|
||||
private static Map<String, float[]> modelMap;
|
||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||
private int retCode;
|
||||
|
||||
/**
|
||||
* Obtain current status code in client.
|
||||
*
|
||||
* @return current status code in client.
|
||||
*/
|
||||
public FLClientStatus getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
public float[] getFeatureMask() {
|
||||
return featureMask;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain retCode returned by server.
|
||||
*
|
||||
* @return the retCode returned by server.
|
||||
*/
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
public SecureProtocol() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Setting parameters for pairwise masking.
|
||||
*
|
||||
* @param iter current iteration of federated learning task.
|
||||
* @param minSecretNum minimum number of secret fragments required to reconstruct a secret
|
||||
* @param prime teh big prime number used to split secrets into pieces
|
||||
* @param featureSize the total feature size in model
|
||||
*/
|
||||
public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) {
|
||||
this.iteration = iter;
|
||||
this.cipher = new CipherClient(iteration, minSecretNum, prime, featureSize);
|
||||
}
|
||||
|
||||
public FLClientStatus setDPParameter(int iter, double diffEps,
|
||||
double diffDelta, double diffNorm, Map<String, float[]> map) {
|
||||
try {
|
||||
this.iteration = iter;
|
||||
this.dpEps = diffEps;
|
||||
this.dpDelta = diffDelta;
|
||||
this.dpNormClip = diffNorm;
|
||||
this.modelMap = map;
|
||||
status = FLClientStatus.SUCCESS;
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[DPEncrypt] catch Exception in setDPParameter: " + e.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
if (prime == null || prime.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <prime> is null, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return status;
|
||||
this.iteration = iter;
|
||||
this.cipherClient = new CipherClient(iteration, minSecretNum, prime, featureSize);
|
||||
}
|
||||
|
||||
/**
|
||||
* Setting parameters for differential privacy.
|
||||
*
|
||||
* @param iter current iteration of federated learning task.
|
||||
* @param diffEps privacy budget eps of DP mechanism.
|
||||
* @param diffDelta privacy budget delta of DP mechanism.
|
||||
* @param diffNorm normClip factor of DP mechanism.
|
||||
* @param map model weights.
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus setDPParameter(int iter, double diffEps, double diffDelta, double diffNorm, Map<String,
|
||||
float[]> map) {
|
||||
this.iteration = iter;
|
||||
this.dpEps = diffEps;
|
||||
this.dpDelta = diffDelta;
|
||||
this.dpNormClip = diffNorm;
|
||||
this.modelMap = map;
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain the feature names that needed to be encrypted.
|
||||
*
|
||||
* @return the feature names that needed to be encrypted.
|
||||
*/
|
||||
public ArrayList<String> getEncryptFeatureName() {
|
||||
return encryptFeatureName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the parameter encryptFeatureName.
|
||||
*
|
||||
* @param encryptFeatureName the feature names that needed to be encrypted.
|
||||
*/
|
||||
public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) {
|
||||
this.encryptFeatureName = encryptFeatureName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtain the returned timestamp for next request from server.
|
||||
*
|
||||
* @return the timestamp for next request.
|
||||
*/
|
||||
public String getNextRequestTime() {
|
||||
return cipher.getNextRequestTime();
|
||||
return cipherClient.getNextRequestTime();
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate pairwise mask and individual mask.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus pwCreateMask() {
|
||||
LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "==============");
|
||||
LOGGER.info(String.format("[PairWiseMask] ==============request flID: %s ==============",
|
||||
localFLParameter.getFlID()));
|
||||
// round 0
|
||||
status = cipher.exchangeKeys();
|
||||
retCode = cipher.getRetCode();
|
||||
LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============");
|
||||
status = cipherClient.exchangeKeys();
|
||||
retCode = cipherClient.getRetCode();
|
||||
LOGGER.info(String.format("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: %s ",
|
||||
"============", status));
|
||||
if (status != FLClientStatus.SUCCESS) {
|
||||
return status;
|
||||
}
|
||||
// round 1
|
||||
try {
|
||||
status = cipher.shareSecrets();
|
||||
retCode = cipher.getRetCode();
|
||||
LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "=============");
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
status = cipherClient.shareSecrets();
|
||||
retCode = cipherClient.getRetCode();
|
||||
LOGGER.info(String.format("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: %s ",
|
||||
"=============", status));
|
||||
if (status != FLClientStatus.SUCCESS) {
|
||||
return status;
|
||||
}
|
||||
// round2
|
||||
try {
|
||||
featureMask = cipher.doubleMaskingWeight();
|
||||
retCode = cipher.getRetCode();
|
||||
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
|
||||
status = FLClientStatus.FAILED;
|
||||
featureMask = cipherClient.doubleMaskingWeight();
|
||||
if (featureMask == null || featureMask.length <= 0) {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] the returned featureMask from cipherClient.doubleMaskingWeight" +
|
||||
" is null, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
retCode = cipherClient.getRetCode();
|
||||
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add the pairwise mask and individual mask to model weights.
|
||||
*
|
||||
* @param builder the FlatBufferBuilder object used for serialization model weights.
|
||||
* @param trainDataSize trainDataSize tne size of train data set.
|
||||
* @return the serialized model weights after adding masks.
|
||||
*/
|
||||
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||
if (featureMask == null || featureMask.length == 0) {
|
||||
LOGGER.severe("[Encrypt] feature mask is null, please check");
|
||||
return new int[0];
|
||||
}
|
||||
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
|
||||
LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length));
|
||||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
|
@ -142,6 +194,9 @@ public class SecureProtocol {
|
|||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
int featureSize = encryptFeatureName.size();
|
||||
int[] featuresMap = new int[featureSize];
|
||||
|
@ -149,9 +204,13 @@ public class SecureProtocol {
|
|||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
LOGGER.info("[Encrypt] feature name: " + key + " feature size: " + data.length);
|
||||
LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
if (maskIndex >= featureMask.length) {
|
||||
LOGGER.severe("[Encrypt] the maskIndex is out of range for array featureMask, please check");
|
||||
return new int[0];
|
||||
}
|
||||
float maskData = rawData * trainDataSize + featureMask[maskIndex];
|
||||
maskIndex += 1;
|
||||
data[j] = maskData;
|
||||
|
@ -164,17 +223,23 @@ public class SecureProtocol {
|
|||
return featuresMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconstruct the secrets used for unmasking model weights.
|
||||
*
|
||||
* @return current status code in client.
|
||||
*/
|
||||
public FLClientStatus pwUnmasking() {
|
||||
status = cipher.reconstructSecrets(); // round3
|
||||
retCode = cipher.getRetCode();
|
||||
LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "=============");
|
||||
status = cipherClient.reconstructSecrets(); // round3
|
||||
retCode = cipherClient.getRetCode();
|
||||
LOGGER.info(String.format("[Encrypt] =============GetClientList+SendReconstructSecret: %s =============",
|
||||
status));
|
||||
return status;
|
||||
}
|
||||
|
||||
private static float calculateErf(double x) {
|
||||
double result = 0;
|
||||
private static float calculateErf(double erfInput) {
|
||||
double result = 0d;
|
||||
int segmentNum = 10000;
|
||||
double deltaX = x / segmentNum;
|
||||
double deltaX = erfInput / segmentNum;
|
||||
result += 1;
|
||||
for (int i = 1; i < segmentNum; i++) {
|
||||
result += 2 * Math.exp(-Math.pow(deltaX * i, 2));
|
||||
|
@ -183,33 +248,36 @@ public class SecureProtocol {
|
|||
return (float) (result * deltaX / Math.pow(Math.PI, 0.5));
|
||||
}
|
||||
|
||||
private static double calculatePhi(double t) {
|
||||
return 0.5 * (1.0 + calculateErf((t / Math.sqrt(2.0))));
|
||||
private static double calculatePhi(double phiInput) {
|
||||
return 0.5 * (1.0 + calculateErf((phiInput / Math.sqrt(2.0))));
|
||||
}
|
||||
|
||||
private static double calculateBPositive(double eps, double s) {
|
||||
return calculatePhi(Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
|
||||
private static double calculateBPositive(double eps, double calInput) {
|
||||
return calculatePhi(Math.sqrt(eps * calInput)) -
|
||||
Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
|
||||
}
|
||||
|
||||
private static double calculateBNegative(double eps, double s) {
|
||||
return calculatePhi(-Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
|
||||
private static double calculateBNegative(double eps, double calInput) {
|
||||
return calculatePhi(-Math.sqrt(eps * calInput)) -
|
||||
Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
|
||||
}
|
||||
|
||||
private static double calculateSPositive(double eps, double targetDelta, double sInf, double sSup) {
|
||||
double deltaSup = calculateBPositive(eps, sSup);
|
||||
private static double calculateSPositive(double eps, double targetDelta, double initSInf, double initSSup) {
|
||||
double deltaSup = calculateBPositive(eps, initSSup);
|
||||
double sInf = initSInf;
|
||||
double sSup = initSSup;
|
||||
while (deltaSup <= targetDelta) {
|
||||
sInf = sSup;
|
||||
sSup = 2 * sInf;
|
||||
deltaSup = calculateBPositive(eps, sSup);
|
||||
}
|
||||
|
||||
double sMid = sInf + (sSup - sInf) / 2.0;
|
||||
int iterMax = 1000;
|
||||
int iters = 0;
|
||||
while (true) {
|
||||
double b = calculateBPositive(eps, sMid);
|
||||
if (b <= targetDelta) {
|
||||
if (targetDelta - b <= deltaError) {
|
||||
double bPositive = calculateBPositive(eps, sMid);
|
||||
if (bPositive <= targetDelta) {
|
||||
if (targetDelta - bPositive <= deltaError) {
|
||||
break;
|
||||
} else {
|
||||
sInf = sMid;
|
||||
|
@ -226,8 +294,10 @@ public class SecureProtocol {
|
|||
return sMid;
|
||||
}
|
||||
|
||||
private static double calculateSNegative(double eps, double targetDelta, double sInf, double sSup) {
|
||||
double deltaSup = calculateBNegative(eps, sSup);
|
||||
private static double calculateSNegative(double eps, double targetDelta, double initSInf, double initSSup) {
|
||||
double deltaSup = calculateBNegative(eps, initSSup);
|
||||
double sInf = initSInf;
|
||||
double sSup = initSSup;
|
||||
while (deltaSup > targetDelta) {
|
||||
sInf = sSup;
|
||||
sSup = 2 * sInf;
|
||||
|
@ -238,9 +308,9 @@ public class SecureProtocol {
|
|||
int iterMax = 1000;
|
||||
int iters = 0;
|
||||
while (true) {
|
||||
double b = calculateBNegative(eps, sMid);
|
||||
if (b <= targetDelta) {
|
||||
if (targetDelta - b <= deltaError) {
|
||||
double bNegative = calculateBNegative(eps, sMid);
|
||||
if (bNegative <= targetDelta) {
|
||||
if (targetDelta - bNegative <= deltaError) {
|
||||
break;
|
||||
} else {
|
||||
sSup = sMid;
|
||||
|
@ -259,17 +329,26 @@ public class SecureProtocol {
|
|||
|
||||
private static double calculateSigma(double clipNorm, double eps, double targetDelta) {
|
||||
double deltaZero = calculateBPositive(eps, 0);
|
||||
double alpha = 1;
|
||||
double alpha = 1d;
|
||||
if (targetDelta > deltaZero) {
|
||||
double s = calculateSPositive(eps, targetDelta, 0, 1);
|
||||
alpha = Math.sqrt(1.0 + s / 2.0) - Math.sqrt(s / 2.0);
|
||||
double sPositive = calculateSPositive(eps, targetDelta, 0, 1);
|
||||
alpha = Math.sqrt(1.0 + sPositive / 2.0) - Math.sqrt(sPositive / 2.0);
|
||||
} else if (targetDelta < deltaZero) {
|
||||
double s = calculateSNegative(eps, targetDelta, 0, 1);
|
||||
alpha = Math.sqrt(1.0 + s / 2.0) + Math.sqrt(s / 2.0);
|
||||
double sNegative = calculateSNegative(eps, targetDelta, 0, 1);
|
||||
alpha = Math.sqrt(1.0 + sNegative / 2.0) + Math.sqrt(sNegative / 2.0);
|
||||
} else {
|
||||
LOGGER.info(Common.addTag("[Encrypt] targetDelta = deltaZero"));
|
||||
}
|
||||
return alpha * clipNorm / Math.sqrt(2.0 * eps);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add differential privacy mask to model weights.
|
||||
*
|
||||
* @param builder the FlatBufferBuilder object used for serialization model weights.
|
||||
* @param trainDataSize tne size of train data set.
|
||||
* @return the serialized model weights after adding masks.
|
||||
*/
|
||||
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||
// get feature map
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
|
@ -279,6 +358,9 @@ public class SecureProtocol {
|
|||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
Map<String, float[]> mapBeforeTrain = modelMap;
|
||||
int featureSize = encryptFeatureName.size();
|
||||
|
@ -286,19 +368,18 @@ public class SecureProtocol {
|
|||
double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta);
|
||||
LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "============="));
|
||||
|
||||
// prepare gaussian noise
|
||||
SecureRandom random = new SecureRandom();
|
||||
int randomInt = random.nextInt();
|
||||
Random r = new Random(randomInt);
|
||||
|
||||
// calculate l2-norm of all layers' update array
|
||||
double updateL2Norm = 0;
|
||||
double updateL2Norm = 0d;
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
if (j >= dataBeforeTrain.length) {
|
||||
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
|
||||
return new int[0];
|
||||
}
|
||||
float rawDataBeforeTrain = dataBeforeTrain[j];
|
||||
float updateData = rawData - rawDataBeforeTrain;
|
||||
updateL2Norm += updateData * updateData;
|
||||
|
@ -311,11 +392,26 @@ public class SecureProtocol {
|
|||
int[] featuresMap = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
if (!map.containsKey(key)) {
|
||||
LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!");
|
||||
return new int[0];
|
||||
}
|
||||
float[] data = map.get(key);
|
||||
float[] data2 = new float[data.length];
|
||||
if (!mapBeforeTrain.containsKey(key)) {
|
||||
LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!");
|
||||
return new int[0];
|
||||
}
|
||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||
|
||||
// prepare gaussian noise
|
||||
SecureRandom secureRandom = Common.getSecureRandom();
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
if (j >= dataBeforeTrain.length) {
|
||||
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
|
||||
return new int[0];
|
||||
}
|
||||
float rawDataBeforeTrain = dataBeforeTrain[j];
|
||||
float updateData = rawData - rawDataBeforeTrain;
|
||||
|
||||
|
@ -323,7 +419,7 @@ public class SecureProtocol {
|
|||
updateData *= clipFactor;
|
||||
|
||||
// add noise
|
||||
double gaussianNoise = r.nextGaussian() * gaussianSigma;
|
||||
double gaussianNoise = secureRandom.nextGaussian() * gaussianSigma;
|
||||
updateData += gaussianNoise;
|
||||
data2[j] = rawDataBeforeTrain + updateData;
|
||||
data2[j] = data2[j] * trainDataSize;
|
||||
|
@ -335,5 +431,4 @@ public class SecureProtocol {
|
|||
}
|
||||
return featuresMap;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,8 +13,14 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
/**
|
||||
* The training mode of federated learning.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public enum ServerMod {
|
||||
FEDERATED_LEARNING,
|
||||
HYBRID_TRAINING
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
|
@ -13,13 +12,20 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
|
||||
import mindspore.schema.FLPlan;
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestFLJob;
|
||||
import mindspore.schema.ResponseCode;
|
||||
|
@ -28,15 +34,28 @@ import mindspore.schema.ResponseFLJob;
|
|||
import java.util.ArrayList;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
/**
|
||||
* StartFLJob
|
||||
*
|
||||
* @since 2021-08-25
|
||||
*/
|
||||
public class StartFLJob {
|
||||
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString());
|
||||
private static volatile StartFLJob startFLJob;
|
||||
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString());
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
|
||||
private int featureSize;
|
||||
private String nextRequestTime;
|
||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||
|
||||
private StartFLJob() {
|
||||
}
|
||||
|
||||
class RequestStartFLJobBuilder {
|
||||
private RequestFLJob requestFLJob;
|
||||
|
@ -51,21 +70,53 @@ public class StartFLJob {
|
|||
builder = new FlatBufferBuilder();
|
||||
}
|
||||
|
||||
/**
|
||||
* set flName
|
||||
*
|
||||
* @param name String
|
||||
* @return RequestStartFLJobBuilder
|
||||
*/
|
||||
public RequestStartFLJobBuilder flName(String name) {
|
||||
if (name == null || name.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <name> is null or empty, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.nameOffset = this.builder.createString(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* set id
|
||||
*
|
||||
* @param id String
|
||||
* @return RequestStartFLJobBuilder
|
||||
*/
|
||||
public RequestStartFLJobBuilder id(String id) {
|
||||
if (id == null || id.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <id> is null or empty, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.idOffset = this.builder.createString(id);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* set time
|
||||
*
|
||||
* @param timestamp long
|
||||
* @return RequestStartFLJobBuilder
|
||||
*/
|
||||
public RequestStartFLJobBuilder time(long timestamp) {
|
||||
this.timestampOffset = builder.createString(String.valueOf(timestamp));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* set dataSize
|
||||
*
|
||||
* @param dataSize int
|
||||
* @return RequestStartFLJobBuilder
|
||||
*/
|
||||
public RequestStartFLJobBuilder dataSize(int dataSize) {
|
||||
// temp code need confirm
|
||||
this.dataSize = dataSize;
|
||||
|
@ -73,11 +124,22 @@ public class StartFLJob {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* set iteration
|
||||
*
|
||||
* @param iteration iteration
|
||||
* @return RequestStartFLJobBuilder
|
||||
*/
|
||||
public RequestStartFLJobBuilder iteration(int iteration) {
|
||||
this.iteration = iteration;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* build protobuffer
|
||||
*
|
||||
* @return byte[] data
|
||||
*/
|
||||
public byte[] build() {
|
||||
int root = RequestFLJob.createRequestFLJob(this.builder, this.nameOffset, this.idOffset, this.iteration,
|
||||
this.dataSize, this.timestampOffset);
|
||||
|
@ -86,20 +148,11 @@ public class StartFLJob {
|
|||
}
|
||||
}
|
||||
|
||||
private static volatile StartFLJob startFLJob;
|
||||
|
||||
private FLClientStatus status;
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private int featureSize;
|
||||
private String nextRequestTime;
|
||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||
|
||||
private StartFLJob() {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* getInstance of StartFLJob
|
||||
*
|
||||
* @return StartFLJob instance
|
||||
*/
|
||||
public static StartFLJob getInstance() {
|
||||
StartFLJob localRef = startFLJob;
|
||||
if (localRef == null) {
|
||||
|
@ -117,6 +170,14 @@ public class StartFLJob {
|
|||
return nextRequestTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* get request start FLJob
|
||||
*
|
||||
* @param dataSize dataSize
|
||||
* @param iteration iteration
|
||||
* @param time time
|
||||
* @return byte[] data
|
||||
*/
|
||||
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) {
|
||||
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
|
||||
return builder.flName(flParameter.getFlName())
|
||||
|
@ -135,6 +196,7 @@ public class StartFLJob {
|
|||
return encryptFeatureName;
|
||||
}
|
||||
|
||||
|
||||
private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) {
|
||||
int fmCount = flJob.featureMapLength();
|
||||
encryptFeatureName.clear();
|
||||
|
@ -149,6 +211,10 @@ public class StartFLJob {
|
|||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||
albertFeatureMaps.add(feature);
|
||||
|
@ -160,19 +226,23 @@ public class StartFLJob {
|
|||
} else {
|
||||
continue;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
|
||||
"weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------"));
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
|
||||
"model-----------------"));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
|
||||
inferFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||
albertFeatureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
|
@ -182,16 +252,22 @@ public class StartFLJob {
|
|||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(featureName);
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
|
||||
"weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||
featureMaps);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||
return FLClientStatus.FAILED;
|
||||
|
@ -206,11 +282,16 @@ public class StartFLJob {
|
|||
encryptFeatureName.clear();
|
||||
for (int i = 0; i < fmCount; i++) {
|
||||
FeatureMap feature = flJob.featureMap(i);
|
||||
if (feature == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
String featureName = feature.weightFullname();
|
||||
featureMaps.add(feature);
|
||||
featureSize += feature.dataLength();
|
||||
encryptFeatureName.add(featureName);
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " +
|
||||
feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||
}
|
||||
int tag = 0;
|
||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||
|
@ -223,7 +304,22 @@ public class StartFLJob {
|
|||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
|
||||
/**
|
||||
* response res
|
||||
*
|
||||
* @param flJob ResponseFLJob
|
||||
* @return FLClientStatus
|
||||
*/
|
||||
public FLClientStatus doResponse(ResponseFLJob flJob) {
|
||||
if (flJob == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
FLPlan flPlanConfig = flJob.flPlanConfig();
|
||||
if (flPlanConfig == null) {
|
||||
LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
|
||||
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
|
||||
|
@ -236,11 +332,12 @@ public class StartFLJob {
|
|||
|
||||
switch (retCode) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
localFLParameter.setServerMod(flJob.flPlanConfig().serverMode());
|
||||
localFLParameter.setServerMod(flPlanConfig.serverMode());
|
||||
if (ALBERT.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
|
||||
status = parseResponseAlbert(flJob);
|
||||
} else if (LENET.equals(flParameter.getFlName())) {
|
||||
}
|
||||
if (LENET.equals(flParameter.getFlName())) {
|
||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
|
||||
status = parseResponseLenet(flJob);
|
||||
}
|
||||
|
@ -256,8 +353,4 @@ public class StartFLJob {
|
|||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus getStatus() {
|
||||
return this.status;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,35 +13,49 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
import mindspore.schema.ResponseGetModel;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.mindspore.flclient.model.AlInferBert;
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
|
||||
import mindspore.schema.ResponseGetModel;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* SyncFLJob defines the APIs for federated learning task.
|
||||
* API flJobRun() for starting federated learning on the device, the API modelInference() for inference on the
|
||||
* device, and the API getModel() for obtaining the latest model on the cloud.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class SyncFLJob {
|
||||
private static final Logger LOGGER = Logger.getLogger(SyncFLJob.class.toString());
|
||||
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
|
||||
private Map<String, float[]> oldFeatureMap;
|
||||
|
||||
public SyncFLJob() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a federated learning task on the device.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus flJobRun() {
|
||||
Common.setSecureRandom(new SecureRandom());
|
||||
localFLParameter.setFlID(flParameter.getClientID());
|
||||
FLLiteClient client = new FLLiteClient();
|
||||
FLClientStatus curStatus;
|
||||
|
@ -58,17 +72,14 @@ public class SyncFLJob {
|
|||
if (trainDataSize <= 0) {
|
||||
LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0"));
|
||||
curStatus = FLClientStatus.FAILED;
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
|
||||
client.getRetCode());
|
||||
break;
|
||||
}
|
||||
client.setTrainDataSize(trainDataSize);
|
||||
|
||||
// startFLJob
|
||||
curStatus = client.startFLJob();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.startFLJob();
|
||||
}
|
||||
curStatus = startFLJob(client);
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
continue;
|
||||
|
@ -100,11 +111,7 @@ public class SyncFLJob {
|
|||
LOGGER.info(Common.addTag("[train] train succeed"));
|
||||
|
||||
// updateModel
|
||||
curStatus = client.updateModel();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.updateModel();
|
||||
}
|
||||
curStatus = updateModel(client);
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
continue;
|
||||
|
@ -124,11 +131,7 @@ public class SyncFLJob {
|
|||
}
|
||||
|
||||
// getModel
|
||||
curStatus = client.getModel();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.getModel();
|
||||
}
|
||||
curStatus = getModel(client);
|
||||
if (curStatus == FLClientStatus.RESTART) {
|
||||
restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||
continue;
|
||||
|
@ -142,7 +145,8 @@ public class SyncFLJob {
|
|||
|
||||
// evaluate model after getting model from server
|
||||
if (flParameter.getTestDataset().equals("null")) {
|
||||
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the combine model"));
|
||||
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the model after getting" +
|
||||
" model from server"));
|
||||
} else {
|
||||
curStatus = client.evaluateModel();
|
||||
if (curStatus == FLClientStatus.FAILED) {
|
||||
|
@ -151,33 +155,68 @@ public class SyncFLJob {
|
|||
}
|
||||
LOGGER.info(Common.addTag("[evaluate] evaluate succeed"));
|
||||
}
|
||||
LOGGER.info(Common.addTag("========================================================the total response of " + client.getIteration() + ": " + curStatus + "======================================================================"));
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
|
||||
LOGGER.info(Common.addTag("========================================================the total response of "
|
||||
+ client.getIteration() + ": " + curStatus +
|
||||
"======================================================================"));
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
|
||||
client.getRetCode());
|
||||
} while (client.getIteration() < client.getIterations());
|
||||
client.finalize();
|
||||
client.freeSession();
|
||||
LOGGER.info(Common.addTag("flJobRun finish"));
|
||||
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private FLClientStatus startFLJob(FLLiteClient client) {
|
||||
FLClientStatus curStatus = client.startFLJob();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.startFLJob();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private FLClientStatus updateModel(FLLiteClient client) {
|
||||
FLClientStatus curStatus = client.updateModel();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.updateModel();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private FLClientStatus getModel(FLLiteClient client) {
|
||||
FLClientStatus curStatus = client.getModel();
|
||||
while (curStatus == FLClientStatus.WAIT) {
|
||||
waitSomeTime();
|
||||
curStatus = client.getModel();
|
||||
}
|
||||
return curStatus;
|
||||
}
|
||||
|
||||
private void updateDpNormClip(FLLiteClient client) {
|
||||
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
|
||||
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
|
||||
int currentIter = client.getIteration();
|
||||
Map<String, float[]> fedFeatureMap = getFeatureMap();
|
||||
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
|
||||
if (fedWeightUpdateNorm == -1) {
|
||||
LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " +
|
||||
"-1, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
|
||||
float newNormCLip = (float) client.dpNormClipFactor * fedWeightUpdateNorm;
|
||||
float newNormCLip = (float) client.getDpNormClipFactor() * fedWeightUpdateNorm;
|
||||
if (currentIter == 1) {
|
||||
client.dpNormClipAdapt = newNormCLip;
|
||||
client.setDpNormClipAdapt(newNormCLip);
|
||||
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
||||
} else {
|
||||
if (newNormCLip < client.dpNormClipAdapt) {
|
||||
client.dpNormClipAdapt = newNormCLip;
|
||||
if (newNormCLip < client.getDpNormClipAdapt()) {
|
||||
client.setDpNormClipAdapt(newNormCLip);
|
||||
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
||||
}
|
||||
}
|
||||
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.dpNormClipAdapt));
|
||||
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.getDpNormClipAdapt()));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -190,11 +229,16 @@ public class SyncFLJob {
|
|||
}
|
||||
|
||||
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
|
||||
float updateL2Norm = 0;
|
||||
float updateL2Norm = 0f;
|
||||
for (String key : originalData.keySet()) {
|
||||
float[] data = originalData.get(key);
|
||||
float[] dataAfterUpdate = newData.get(key);
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
if (j >= dataAfterUpdate.length) {
|
||||
LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " +
|
||||
"please check");
|
||||
return -1;
|
||||
}
|
||||
float updateData = data[j] - dataAfterUpdate[j];
|
||||
updateL2Norm += updateData * updateData;
|
||||
}
|
||||
|
@ -215,12 +259,22 @@ public class SyncFLJob {
|
|||
return featureMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts an inference task on the device.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public int[] modelInference() {
|
||||
int[] labels = new int[0];
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
LOGGER.info(Common.addTag("===========model inference============="));
|
||||
labels = alInferBert.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile());
|
||||
labels = alInferBert.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset(),
|
||||
flParameter.getVocabFile(), flParameter.getIdsFile());
|
||||
if (labels == null || labels.length == 0) {
|
||||
LOGGER.severe("[model inference] the returned label from adInferBert.inferModel() is null, please " +
|
||||
"check");
|
||||
}
|
||||
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
||||
SessionUtil.free(alInferBert.getTrainSession());
|
||||
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
||||
|
@ -228,49 +282,62 @@ public class SyncFLJob {
|
|||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
LOGGER.info(Common.addTag("===========model inference============="));
|
||||
labels = trainLenet.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset().split(",")[0]);
|
||||
if (labels == null || labels.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
|
||||
}
|
||||
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
||||
SessionUtil.free(trainLenet.getTrainSession());
|
||||
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
||||
}
|
||||
if (labels.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
|
||||
}
|
||||
return labels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Obtains the latest model on the cloud.
|
||||
*
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus getModel() {
|
||||
Common.setSecureRandom(Common.getFastSecureRandom());
|
||||
int tag = 0;
|
||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
||||
FLClientStatus status;
|
||||
try {
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
|
||||
flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the " +
|
||||
"return is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " +
|
||||
flParameter.getInferModelPath() + " Create inference Session============="));
|
||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
|
||||
flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||
}
|
||||
if (tag == -1) {
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
|
||||
"is -1"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
FLCommunication flCommunication = FLCommunication.getInstance();
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||
flParameter.getDomainName());
|
||||
GetModel getModelBuf = GetModel.getInstance();
|
||||
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
|
||||
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(message)) {
|
||||
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
|
||||
"again"));
|
||||
status = FLClientStatus.WAIT;
|
||||
return status;
|
||||
}
|
||||
|
@ -279,8 +346,8 @@ public class SyncFLJob {
|
|||
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
|
||||
status = getModelBuf.doResponse(responseDataBuf);
|
||||
LOGGER.info(Common.addTag("[getModel] success!"));
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage()));
|
||||
} catch (Exception ex) {
|
||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + ex.getMessage()));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
|
@ -299,19 +366,16 @@ public class SyncFLJob {
|
|||
}
|
||||
|
||||
private void waitSomeTime() {
|
||||
if (flParameter.getSleepTime() != 0)
|
||||
if (flParameter.getSleepTime() != 0) {
|
||||
Common.sleep(flParameter.getSleepTime());
|
||||
else
|
||||
} else {
|
||||
Common.sleep(SLEEP_TIME);
|
||||
}
|
||||
}
|
||||
|
||||
private void waitNextReqTime(String nextReqTime) {
|
||||
if (flParameter.isTimer()) {
|
||||
long waitTime = Common.getWaitTime(nextReqTime);
|
||||
Common.sleep(waitTime);
|
||||
} else {
|
||||
waitSomeTime();
|
||||
}
|
||||
long waitTime = Common.getWaitTime(nextReqTime);
|
||||
Common.sleep(waitTime);
|
||||
}
|
||||
|
||||
private void restart(String tag, String nextReqTime, int iteration, int retcode) {
|
||||
|
@ -322,7 +386,8 @@ public class SyncFLJob {
|
|||
|
||||
private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) {
|
||||
LOGGER.info(Common.addTag(tag + " failed"));
|
||||
LOGGER.info(Common.addTag("========================================================the total response of " + iteration + ": " + curStatus + "======================================================================"));
|
||||
LOGGER.info(Common.addTag("=========================================the total response of " +
|
||||
iteration + ": " + curStatus + "========================================="));
|
||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
|
||||
}
|
||||
|
||||
|
@ -334,53 +399,27 @@ public class SyncFLJob {
|
|||
String flName = args[4];
|
||||
String trainModelPath = args[5];
|
||||
String inferModelPath = args[6];
|
||||
String clientID = args[7];
|
||||
String ip = args[8];
|
||||
boolean useSSL = Boolean.parseBoolean(args[9]);
|
||||
int port = Integer.parseInt(args[10]);
|
||||
int timeWindow = Integer.parseInt(args[11]);
|
||||
boolean useElb = Boolean.parseBoolean(args[12]);
|
||||
int serverNum = Integer.parseInt(args[13]);
|
||||
boolean useHttps = Boolean.parseBoolean(args[14]);
|
||||
String certPath = args[15];
|
||||
String task = args[16];
|
||||
boolean useSSL = Boolean.parseBoolean(args[7]);
|
||||
String domainName = args[8];
|
||||
boolean useElb = Boolean.parseBoolean(args[9]);
|
||||
int serverNum = Integer.parseInt(args[10]);
|
||||
String certPath = args[11];
|
||||
String task = args[12];
|
||||
|
||||
FLParameter flParameter = FLParameter.getInstance();
|
||||
LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset));
|
||||
LOGGER.info(Common.addTag("[args] vocabFile: " + vocabFile));
|
||||
LOGGER.info(Common.addTag("[args] idsFile: " + idsFile));
|
||||
LOGGER.info(Common.addTag("[args] testDataset: " + testDataset));
|
||||
LOGGER.info(Common.addTag("[args] flName: " + flName));
|
||||
LOGGER.info(Common.addTag("[args] trainModelPath: " + trainModelPath));
|
||||
LOGGER.info(Common.addTag("[args] inferModelPath: " + inferModelPath));
|
||||
LOGGER.info(Common.addTag("[args] clientID: " + clientID));
|
||||
LOGGER.info(Common.addTag("[args] ip: " + ip));
|
||||
LOGGER.info(Common.addTag("[args] useSSL: " + useSSL));
|
||||
LOGGER.info(Common.addTag("[args] port: " + port));
|
||||
LOGGER.info(Common.addTag("[args] timeWindow: " + timeWindow));
|
||||
LOGGER.info(Common.addTag("[args] useElb: " + useElb));
|
||||
LOGGER.info(Common.addTag("[args] serverNum: " + serverNum));
|
||||
LOGGER.info(Common.addTag("[args] useHttps: " + useHttps));
|
||||
LOGGER.info(Common.addTag("[args] certPath: " + certPath));
|
||||
LOGGER.info(Common.addTag("[args] task: " + task));
|
||||
|
||||
flParameter.setClientID(clientID);
|
||||
SyncFLJob syncFLJob = new SyncFLJob();
|
||||
if (task.equals("train")) {
|
||||
flParameter.setUseHttps(useHttps);
|
||||
if (useSSL) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setHostName(ip);
|
||||
flParameter.setTrainDataset(trainDataset);
|
||||
flParameter.setFlName(flName);
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setTestDataset(testDataset);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setIp(ip);
|
||||
flParameter.setUseSSL(useSSL);
|
||||
flParameter.setPort(port);
|
||||
flParameter.setTimeWindow(timeWindow);
|
||||
flParameter.setDomainName(domainName);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
if (ALBERT.equals(flName)) {
|
||||
|
@ -398,17 +437,14 @@ public class SyncFLJob {
|
|||
}
|
||||
syncFLJob.modelInference();
|
||||
} else if (task.equals("getModel")) {
|
||||
flParameter.setUseHttps(useHttps);
|
||||
if (useSSL) {
|
||||
flParameter.setCertPath(certPath);
|
||||
}
|
||||
flParameter.setHostName(ip);
|
||||
flParameter.setFlName(flName);
|
||||
flParameter.setTrainModelPath(trainModelPath);
|
||||
flParameter.setInferModelPath(inferModelPath);
|
||||
flParameter.setIp(ip);
|
||||
flParameter.setUseSSL(useSSL);
|
||||
flParameter.setPort(port);
|
||||
flParameter.setDomainName(domainName);
|
||||
flParameter.setUseElb(useElb);
|
||||
flParameter.setServerNum(serverNum);
|
||||
syncFLJob.getModel();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,10 +16,15 @@
|
|||
|
||||
package com.mindspore.flclient;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.model.AlTrainBert;
|
||||
import com.mindspore.flclient.model.SessionUtil;
|
||||
import com.mindspore.flclient.model.TrainLenet;
|
||||
|
||||
import mindspore.schema.FeatureMap;
|
||||
import mindspore.schema.RequestUpdateModel;
|
||||
import mindspore.schema.ResponseCode;
|
||||
|
@ -31,145 +36,31 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||
|
||||
/**
|
||||
* Define the serialization method, handle the response message returned from server for updateModel request.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class UpdateModel {
|
||||
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
|
||||
private static volatile UpdateModel updateModel;
|
||||
|
||||
static {
|
||||
System.loadLibrary("mindspore-lite-jni");
|
||||
}
|
||||
|
||||
class RequestUpdateModelBuilder {
|
||||
private RequestUpdateModel requestUM;
|
||||
private FlatBufferBuilder builder;
|
||||
private int fmOffset = 0;
|
||||
private int nameOffset = 0;
|
||||
private int idOffset = 0;
|
||||
private int timestampOffset = 0;
|
||||
private int iteration = 0;
|
||||
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
|
||||
|
||||
public RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
|
||||
builder = new FlatBufferBuilder();
|
||||
this.encryptLevel = encryptLevel;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder flName(String name) {
|
||||
this.nameOffset = this.builder.createString(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder time() {
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
this.timestampOffset = builder.createString(String.valueOf(time));
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder iteration(int iteration) {
|
||||
this.iteration = iteration;
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder id(String id) {
|
||||
this.idOffset = this.builder.createString(id);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
|
||||
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
|
||||
switch (encryptLevel) {
|
||||
case PW_ENCRYPT:
|
||||
try {
|
||||
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
|
||||
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is null, please check");
|
||||
throw new RuntimeException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
||||
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
||||
return this;
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe("[Encrypt] catch error in maskModel: " + e.getMessage());
|
||||
throw new RuntimeException();
|
||||
}
|
||||
case DP_ENCRYPT:
|
||||
try {
|
||||
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
|
||||
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is null, please check");
|
||||
throw new RuntimeException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
|
||||
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
||||
return this;
|
||||
} catch (Exception e) {
|
||||
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
|
||||
throw new RuntimeException();
|
||||
}
|
||||
case NOT_ENCRYPT:
|
||||
default:
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
int featureSize = encryptFeatureName.size();
|
||||
int[] fmOffsets = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature size: " + data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
float rawData = data[j];
|
||||
data[j] = data[j] * trainDataSize;
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, data);
|
||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
fmOffsets[i] = featureMap;
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] build() {
|
||||
RequestUpdateModel.startRequestUpdateModel(this.builder);
|
||||
RequestUpdateModel.addFlName(builder, nameOffset);
|
||||
RequestUpdateModel.addFlId(this.builder, idOffset);
|
||||
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
|
||||
RequestUpdateModel.addIteration(builder, this.iteration);
|
||||
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
|
||||
int root = RequestUpdateModel.endRequestUpdateModel(builder);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
}
|
||||
}
|
||||
|
||||
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private String nextRequestTime;
|
||||
private FLClientStatus status;
|
||||
private static volatile UpdateModel updateModel;
|
||||
|
||||
private UpdateModel() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the singleton object of the class UpdateModel.
|
||||
*
|
||||
* @return the singleton object of the class UpdateModel.
|
||||
*/
|
||||
public static UpdateModel getInstance() {
|
||||
UpdateModel localRef = updateModel;
|
||||
if (localRef == null) {
|
||||
|
@ -183,25 +74,35 @@ public class UpdateModel {
|
|||
return localRef;
|
||||
}
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public FLClientStatus getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a flatBuffer builder of RequestUpdateModel.
|
||||
*
|
||||
* @param iteration current iteration of federated learning task.
|
||||
* @param secureProtocol the object that defines encryption and decryption methods.
|
||||
* @param trainDataSize the size of train date set.
|
||||
* @return the flatBuffer builder of RequestUpdateModel in byte[] format.
|
||||
*/
|
||||
public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) {
|
||||
RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel());
|
||||
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID()).featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
|
||||
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID())
|
||||
.featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle the response message returned from server.
|
||||
*
|
||||
* @param response the response message returned from server.
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus doResponse(ResponseUpdateModel response) {
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================"));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode()));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason()));
|
||||
LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime()));
|
||||
nextRequestTime = response.nextReqTime();
|
||||
switch (response.retcode()) {
|
||||
case (ResponseCode.SUCCEED):
|
||||
LOGGER.info(Common.addTag("[updateModel] updateModel success"));
|
||||
|
@ -213,8 +114,165 @@ public class UpdateModel {
|
|||
LOGGER.warning(Common.addTag("[updateModel] catch RequestError or SystemError"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[updateModel]the return <retcode> from server is invalid: " + response.retcode()));
|
||||
LOGGER.severe(Common.addTag("[updateModel]the return <retCode> from server is invalid: " +
|
||||
response.retcode()));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
class RequestUpdateModelBuilder {
|
||||
private RequestUpdateModel requestUM;
|
||||
private FlatBufferBuilder builder;
|
||||
private int fmOffset = 0;
|
||||
private int nameOffset = 0;
|
||||
private int idOffset = 0;
|
||||
private int timestampOffset = 0;
|
||||
private int iteration = 0;
|
||||
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
|
||||
|
||||
private RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
|
||||
builder = new FlatBufferBuilder();
|
||||
this.encryptLevel = encryptLevel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize the element flName in RequestUpdateModel.
|
||||
*
|
||||
* @param name the model name.
|
||||
* @return the RequestUpdateModelBuilder object.
|
||||
*/
|
||||
private RequestUpdateModelBuilder flName(String name) {
|
||||
if (name == null || name.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the parameter of <name> is null or empty, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.nameOffset = this.builder.createString(name);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize the element timestamp in RequestUpdateModel.
|
||||
*
|
||||
* @return the RequestUpdateModelBuilder object.
|
||||
*/
|
||||
private RequestUpdateModelBuilder time() {
|
||||
Date date = new Date();
|
||||
long time = date.getTime();
|
||||
this.timestampOffset = builder.createString(String.valueOf(time));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize the element iteration in RequestUpdateModel.
|
||||
*
|
||||
* @param iteration current iteration of federated learning task.
|
||||
* @return the RequestUpdateModelBuilder object.
|
||||
*/
|
||||
private RequestUpdateModelBuilder iteration(int iteration) {
|
||||
this.iteration = iteration;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize the element fl_id in RequestUpdateModel.
|
||||
*
|
||||
* @param id a number that uniquely identifies a client.
|
||||
* @return the RequestUpdateModelBuilder object.
|
||||
*/
|
||||
private RequestUpdateModelBuilder id(String id) {
|
||||
if (id == null || id.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the parameter of <id> is null or empty, please check!"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.idOffset = this.builder.createString(id);
|
||||
return this;
|
||||
}
|
||||
|
||||
private RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
|
||||
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
|
||||
switch (encryptLevel) {
|
||||
case PW_ENCRYPT:
|
||||
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
|
||||
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " +
|
||||
"null, please check");
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
||||
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
||||
return this;
|
||||
case DP_ENCRYPT:
|
||||
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
|
||||
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
|
||||
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " +
|
||||
"null, please check");
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
|
||||
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
||||
return this;
|
||||
case NOT_ENCRYPT:
|
||||
default:
|
||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||
if (flParameter.getFlName().equals(ALBERT)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
|
||||
flParameter.getFlName()));
|
||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
|
||||
".convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
} else if (flParameter.getFlName().equals(LENET)) {
|
||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
|
||||
flParameter.getFlName()));
|
||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||
if (map.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
|
||||
".convertTensorToFeatures>"));
|
||||
status = FLClientStatus.FAILED;
|
||||
}
|
||||
} else {
|
||||
LOGGER.severe(Common.addTag("[updateModel] the flName is not valid"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
int featureSize = encryptFeatureName.size();
|
||||
int[] fmOffsets = new int[featureSize];
|
||||
for (int i = 0; i < featureSize; i++) {
|
||||
String key = encryptFeatureName.get(i);
|
||||
float[] data = map.get(key);
|
||||
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
|
||||
"size: " + data.length));
|
||||
for (int j = 0; j < data.length; j++) {
|
||||
data[j] = data[j] * trainDataSize;
|
||||
}
|
||||
int featureName = builder.createString(key);
|
||||
int weight = FeatureMap.createDataVector(builder, data);
|
||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||
fmOffsets[i] = featureMap;
|
||||
}
|
||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a flatBuffer builder of RequestUpdateModel.
|
||||
*
|
||||
* @return the flatBuffer builder of RequestUpdateModel in byte[] format.
|
||||
*/
|
||||
private byte[] build() {
|
||||
RequestUpdateModel.startRequestUpdateModel(this.builder);
|
||||
RequestUpdateModel.addFlName(builder, nameOffset);
|
||||
RequestUpdateModel.addFlId(this.builder, idOffset);
|
||||
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
|
||||
RequestUpdateModel.addIteration(builder, this.iteration);
|
||||
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
|
||||
int root = RequestUpdateModel.endRequestUpdateModel(builder);
|
||||
builder.finish(root);
|
||||
return builder.sizedByteArray();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,21 +16,33 @@
|
|||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.I_VEC_LEN;
|
||||
import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import javax.crypto.Cipher;
|
||||
import javax.crypto.spec.IvParameterSpec;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.security.InvalidAlgorithmParameterException;
|
||||
import java.security.InvalidKeyException;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.Arrays;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.Cipher;
|
||||
import javax.crypto.IllegalBlockSizeException;
|
||||
import javax.crypto.NoSuchPaddingException;
|
||||
import javax.crypto.spec.IvParameterSpec;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
|
||||
/**
|
||||
* Define encryption and decryption methods.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class AESEncrypt {
|
||||
private static final Logger LOGGER = Logger.getLogger(AESEncrypt.class.toString());
|
||||
/**
|
||||
* 128, 192 or 256
|
||||
*/
|
||||
private static final int KEY_SIZE = 256;
|
||||
private static final int I_VEC_LEN = 16;
|
||||
|
||||
/**
|
||||
* encrypt/decrypt algorithm name
|
||||
|
@ -43,62 +55,145 @@ public class AESEncrypt {
|
|||
private static final String CIPHER_MODE_CTR = "AES/CTR/NoPadding";
|
||||
private static final String CIPHER_MODE_CBC = "AES/CBC/PKCS5PADDING";
|
||||
|
||||
private String CIPHER_MODE;
|
||||
private String cipherMod;
|
||||
|
||||
private static final int RANDOM_LEN = KEY_SIZE / 8;
|
||||
|
||||
private String iVecS = "1111111111111111";
|
||||
private byte[] iVec = iVecS.getBytes("utf-8");
|
||||
|
||||
public AESEncrypt(byte[] key, byte[] iVecIn, String mode) throws UnsupportedEncodingException {
|
||||
/**
|
||||
* Defining a Constructor of the class AESEncrypt.
|
||||
*
|
||||
* @param key the Key.
|
||||
* @param mode the encryption Mode.
|
||||
*/
|
||||
public AESEncrypt(byte[] key, String mode) {
|
||||
if (key == null) {
|
||||
LOGGER.severe(Common.addTag("Key is null"));
|
||||
return;
|
||||
}
|
||||
if (key.length != KEY_SIZE / 8) {
|
||||
if (key.length != KEY_LEN) {
|
||||
LOGGER.severe(Common.addTag("the length of key is not correct"));
|
||||
return;
|
||||
}
|
||||
|
||||
if (mode.contains("CBC")) {
|
||||
CIPHER_MODE = CIPHER_MODE_CBC;
|
||||
cipherMod = CIPHER_MODE_CBC;
|
||||
} else if (mode.contains("CTR")) {
|
||||
CIPHER_MODE = CIPHER_MODE_CTR;
|
||||
cipherMod = CIPHER_MODE_CTR;
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
if (iVecIn == null || iVecIn.length != I_VEC_LEN) {
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Defining the CBC encryption Mode.
|
||||
*
|
||||
* @param key the Key.
|
||||
* @param data the data to be encrypted.
|
||||
* @return the data to be encrypted.
|
||||
*/
|
||||
public byte[] encrypt(byte[] key, byte[] data) {
|
||||
if (key == null) {
|
||||
LOGGER.severe(Common.addTag("Key is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
if (data == null) {
|
||||
LOGGER.severe(Common.addTag("data is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
try {
|
||||
byte[] iVec = new byte[I_VEC_LEN];
|
||||
SecureRandom secureRandom = Common.getSecureRandom();
|
||||
secureRandom.nextBytes(iVec);
|
||||
|
||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||
Cipher cipher = Cipher.getInstance(cipherMod);
|
||||
|
||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
||||
byte[] encrypted = cipher.doFinal(data);
|
||||
byte[] encryptedAddIv = new byte[encrypted.length + iVec.length];
|
||||
System.arraycopy(iVec, 0, encryptedAddIv, 0, iVec.length);
|
||||
System.arraycopy(encrypted, 0, encryptedAddIv, iVec.length, encrypted.length);
|
||||
return encryptedAddIv;
|
||||
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
|
||||
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
|
||||
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException or " +
|
||||
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
|
||||
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
|
||||
return new byte[0];
|
||||
}
|
||||
iVec = iVecIn;
|
||||
}
|
||||
|
||||
public byte[] encrypt(byte[] key, byte[] data) throws Exception {
|
||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
|
||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
||||
byte[] encrypted = cipher.doFinal(data);
|
||||
String encryptResultStr = BaseUtil.byte2HexString(encrypted);
|
||||
return encrypted;
|
||||
/**
|
||||
* Defining the CTR encryption Mode.
|
||||
*
|
||||
* @param key the Key.
|
||||
* @param data the data to be encrypted.
|
||||
* @param iVec the IV value.
|
||||
* @return the data to be encrypted.
|
||||
*/
|
||||
public byte[] encryptCTR(byte[] key, byte[] data, byte[] iVec) {
|
||||
if (key == null) {
|
||||
LOGGER.severe(Common.addTag("Key is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
if (data == null) {
|
||||
LOGGER.severe(Common.addTag("data is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
if (iVec == null || iVec.length != I_VEC_LEN) {
|
||||
LOGGER.severe(Common.addTag("iVec is null or the length of iVec is not valid, it should be " + "I_VEC_LEN"
|
||||
));
|
||||
return new byte[0];
|
||||
}
|
||||
try {
|
||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||
Cipher cipher = Cipher.getInstance(cipherMod);
|
||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
||||
return cipher.doFinal(data);
|
||||
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
|
||||
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
|
||||
LOGGER.severe(Common.addTag("[encryptCTR] catch NoSuchAlgorithmException or " +
|
||||
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
|
||||
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
|
||||
return new byte[0];
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] encryptCTR(byte[] key, byte[] data) throws Exception {
|
||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
|
||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
||||
byte[] encrypted = cipher.doFinal(data);
|
||||
return encrypted;
|
||||
/**
|
||||
* Defining the decrypt method.
|
||||
*
|
||||
* @param key the Key.
|
||||
* @param encryptDataAddIv the data to be decrypted.
|
||||
* @return the data to be decrypted.
|
||||
*/
|
||||
public byte[] decrypt(byte[] key, byte[] encryptDataAddIv) {
|
||||
if (key == null) {
|
||||
LOGGER.severe(Common.addTag("Key is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
if (encryptDataAddIv == null) {
|
||||
LOGGER.severe(Common.addTag("encryptDataAddIv is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
if (encryptDataAddIv.length <= I_VEC_LEN) {
|
||||
LOGGER.severe(Common.addTag("the length of encryptDataAddIv is not valid: " + encryptDataAddIv.length +
|
||||
", it should be > " + I_VEC_LEN));
|
||||
return new byte[0];
|
||||
}
|
||||
try {
|
||||
byte[] iVec = Arrays.copyOfRange(encryptDataAddIv, 0, I_VEC_LEN);
|
||||
byte[] encryptData = Arrays.copyOfRange(encryptDataAddIv, I_VEC_LEN, encryptDataAddIv.length);
|
||||
SecretKeySpec sKeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||
Cipher cipher = Cipher.getInstance(cipherMod);
|
||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||
cipher.init(Cipher.DECRYPT_MODE, sKeySpec, iv);
|
||||
return cipher.doFinal(encryptData);
|
||||
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
|
||||
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
|
||||
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException or " +
|
||||
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
|
||||
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
|
||||
return new byte[0];
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] decrypt(byte[] key, byte[] encryptData) throws Exception {
|
||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
|
||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||
cipher.init(Cipher.DECRYPT_MODE, skeySpec, iv);
|
||||
byte[] origin = cipher.doFinal(encryptData);
|
||||
return origin;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,18 +1,19 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* <p>
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import java.io.UnsupportedEncodingException;
|
||||
|
@ -21,14 +22,23 @@ import java.nio.charset.Charset;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Define conversion methods between basic data types.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class BaseUtil {
|
||||
private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
|
||||
|
||||
public BaseUtil() {
|
||||
}
|
||||
private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B',
|
||||
'C', 'D', 'E', 'F'};
|
||||
|
||||
/**
|
||||
* Convert byte[] to String in hexadecimal format.
|
||||
*
|
||||
* @param bytes the byte[] object.
|
||||
* @return the String object converted from byte[].
|
||||
*/
|
||||
public static String byte2HexString(byte[] bytes) {
|
||||
if (null == bytes) {
|
||||
if (bytes == null) {
|
||||
return null;
|
||||
} else if (bytes.length == 0) {
|
||||
return "";
|
||||
|
@ -36,14 +46,20 @@ public class BaseUtil {
|
|||
char[] chars = new char[bytes.length * 2];
|
||||
|
||||
for (int i = 0; i < bytes.length; ++i) {
|
||||
int b = bytes[i];
|
||||
chars[i * 2] = HEX_DIGITS[(b & 240) >> 4];
|
||||
chars[i * 2 + 1] = HEX_DIGITS[b & 15];
|
||||
int byteNum = bytes[i];
|
||||
chars[i * 2] = HEX_DIGITS[(byteNum & 240) >> 4];
|
||||
chars[i * 2 + 1] = HEX_DIGITS[byteNum & 15];
|
||||
}
|
||||
return new String(chars);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert String in hexadecimal format to byte[].
|
||||
*
|
||||
* @param str the String object.
|
||||
* @return the byte[] converted from String object.
|
||||
*/
|
||||
public static byte[] hexString2ByteArray(String str) {
|
||||
int length = str.length() / 2;
|
||||
byte[] bytes = new byte[length];
|
||||
|
@ -58,8 +74,13 @@ public class BaseUtil {
|
|||
return bytes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert byte[] to BigInteger.
|
||||
*
|
||||
* @param bytes the byte[] object.
|
||||
* @return the BigInteger object converted from byte[].
|
||||
*/
|
||||
public static BigInteger byteArray2BigInteger(byte[] bytes) {
|
||||
|
||||
BigInteger bigInteger = BigInteger.ZERO;
|
||||
for (int i = 0; i < bytes.length; ++i) {
|
||||
int intI = bytes[i];
|
||||
|
@ -72,6 +93,13 @@ public class BaseUtil {
|
|||
return bigInteger;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert String to BigInteger.
|
||||
*
|
||||
* @param str the String object.
|
||||
* @return the BigInteger object converted from String object.
|
||||
* @throws UnsupportedEncodingException if the encoding is not supported.
|
||||
*/
|
||||
public static BigInteger string2BigInteger(String str) throws UnsupportedEncodingException {
|
||||
StringBuilder res = new StringBuilder();
|
||||
byte[] bytes = String.valueOf(str).getBytes("UTF-8");
|
||||
|
@ -83,14 +111,20 @@ public class BaseUtil {
|
|||
return bigInteger;
|
||||
}
|
||||
|
||||
public static String bigInteger2String(BigInteger bigInteger) throws UnsupportedEncodingException {
|
||||
/**
|
||||
* Convert BigInteger to String.
|
||||
*
|
||||
* @param bigInteger the BigInteger object.
|
||||
* @return the String object converted from BigInteger.
|
||||
*/
|
||||
public static String bigInteger2String(BigInteger bigInteger) {
|
||||
StringBuilder res = new StringBuilder();
|
||||
List<Integer> lists = new ArrayList<>();
|
||||
BigInteger bi = bigInteger;
|
||||
BigInteger DIV = BigInteger.valueOf(256);
|
||||
BigInteger div = BigInteger.valueOf(256);
|
||||
while (bi.compareTo(BigInteger.ZERO) > 0) {
|
||||
lists.add(bi.mod(DIV).intValue());
|
||||
bi = bi.divide(DIV);
|
||||
lists.add(bi.mod(div).intValue());
|
||||
bi = bi.divide(div);
|
||||
}
|
||||
for (int i = lists.size() - 1; i >= 0; --i) {
|
||||
res.append((char) (int) (lists.get(i)));
|
||||
|
@ -98,13 +132,19 @@ public class BaseUtil {
|
|||
return res.toString();
|
||||
}
|
||||
|
||||
public static byte[] bigInteger2byteArray(BigInteger bigInteger) throws UnsupportedEncodingException {
|
||||
/**
|
||||
* Convert BigInteger to byte[].
|
||||
*
|
||||
* @param bigInteger the BigInteger object.
|
||||
* @return the byte[] object converted from BigInteger.
|
||||
*/
|
||||
public static byte[] bigInteger2byteArray(BigInteger bigInteger) {
|
||||
List<Integer> lists = new ArrayList<>();
|
||||
BigInteger bi = bigInteger;
|
||||
BigInteger DIV = BigInteger.valueOf(256);
|
||||
BigInteger div = BigInteger.valueOf(256);
|
||||
while (bi.compareTo(BigInteger.ZERO) > 0) {
|
||||
lists.add(bi.mod(DIV).intValue());
|
||||
bi = bi.divide(DIV);
|
||||
lists.add(bi.mod(div).intValue());
|
||||
bi = bi.divide(div);
|
||||
}
|
||||
byte[] res = new byte[lists.size()];
|
||||
for (int i = lists.size() - 1; i >= 0; --i) {
|
||||
|
@ -113,13 +153,19 @@ public class BaseUtil {
|
|||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert Integer to byte[].
|
||||
*
|
||||
* @param num the Integer object.
|
||||
* @return the byte[] object converted from Integer.
|
||||
*/
|
||||
public static byte[] integer2byteArray(Integer num) {
|
||||
List<Integer> lists = new ArrayList<>();
|
||||
Integer bi = num;
|
||||
Integer DIV = 256;
|
||||
Integer div = 256;
|
||||
while (bi > 0) {
|
||||
lists.add(bi % DIV);
|
||||
bi = bi / DIV;
|
||||
lists.add(bi % div);
|
||||
bi = bi / div;
|
||||
}
|
||||
byte[] res = new byte[lists.size()];
|
||||
for (int i = lists.size() - 1; i >= 0; --i) {
|
||||
|
@ -128,8 +174,13 @@ public class BaseUtil {
|
|||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert byte[] to Integer.
|
||||
*
|
||||
* @param bytes the byte[] object.
|
||||
* @return the Integer object converted from byte[].
|
||||
*/
|
||||
public static Integer byteArray2Integer(byte[] bytes) {
|
||||
|
||||
Integer num = 0;
|
||||
for (int i = 0; i < bytes.length; ++i) {
|
||||
int intI = bytes[i];
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,9 +14,13 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.FLClientStatus;
|
||||
import com.mindspore.flclient.FLCommunication;
|
||||
|
@ -25,23 +29,27 @@ import com.mindspore.flclient.LocalFLParameter;
|
|||
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
||||
import com.mindspore.flclient.cipher.struct.EncryptShare;
|
||||
import com.mindspore.flclient.cipher.struct.NewArray;
|
||||
|
||||
import mindspore.schema.GetClientList;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.ReturnClientList;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.Arrays;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
|
||||
|
||||
/**
|
||||
* Define the serialization method, handle the response message returned from server for GetClientList request.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class ClientListReq {
|
||||
|
||||
private static final Logger LOGGER = Logger.getLogger(ClientListReq.class.toString());
|
||||
|
||||
private FLCommunication flCommunication;
|
||||
private String nextRequestTime;
|
||||
private FLParameter flParameter = FLParameter.getInstance();
|
||||
|
@ -64,34 +72,63 @@ public class ClientListReq {
|
|||
return retCode;
|
||||
}
|
||||
|
||||
public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
/**
|
||||
* Send serialized request message of GetClientList to server.
|
||||
*
|
||||
* @param iteration current iteration of federated learning task.
|
||||
* @param u3ClientList list of clients successfully requested in UpdateModel round.
|
||||
* @param decryptSecretsList list to store to decrypted secret fragments.
|
||||
* @param returnShareList List of returned secret fragments from server.
|
||||
* @param cuvKeys Keys used to decrypt secret fragments.
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
public FLClientStatus getClientList(int iteration, List<String> u3ClientList,
|
||||
List<DecryptShareSecrets> decryptSecretsList,
|
||||
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||
FlatBufferBuilder builder = new FlatBufferBuilder();
|
||||
int id = builder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
int time = builder.createString(dateTime);
|
||||
int clientListRoot = GetClientList.createGetClientList(builder, id, iteration, time);
|
||||
builder.finish(clientListRoot);
|
||||
byte[] msg = builder.sizedByteArray();
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg);
|
||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[getClientList] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(responseData)) {
|
||||
LOGGER.info(Common.addTag("[getClientList] the server is not ready now, need wait some time and " +
|
||||
"request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
LOGGER.info(Common.addTag("getClientList responseData size: " + responseData.length));
|
||||
ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer);
|
||||
FLClientStatus status = judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);
|
||||
} catch (IOException ex) {
|
||||
LOGGER.severe(Common.addTag("[getClientList] unsolved error code in getClientList: catch IOException: " +
|
||||
ex.getMessage()));
|
||||
retCode = ResponseCode.RequestError;
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||
/**
|
||||
* Analyze the serialization message returned from server and perform corresponding processing.
|
||||
*
|
||||
* @param bufData Serialized message returned from server.
|
||||
* @param u3ClientList list of clients successfully requested in UpdateModel round.
|
||||
* @param decryptSecretsList list to store decrypted secret fragments.
|
||||
* @param returnShareList List of returned secret fragments from server.
|
||||
* @param cuvKeys Keys used to decrypt secret fragments.
|
||||
* @return the status code corresponding to the response message.
|
||||
*/
|
||||
private FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList,
|
||||
List<DecryptShareSecrets> decryptSecretsList,
|
||||
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
|
@ -109,18 +146,15 @@ public class ClientListReq {
|
|||
String curFlId = bufData.clients(i);
|
||||
u3ClientList.add(curFlId);
|
||||
}
|
||||
try {
|
||||
decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
status = decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys);
|
||||
return status;
|
||||
case (ResponseCode.SucNotReady):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetClientList again!"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
|
||||
"GetClientList again!"));
|
||||
return FLClientStatus.WAIT;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob again"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob" +
|
||||
" again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
|
@ -128,36 +162,66 @@ public class ClientListReq {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is invalid: " + retCode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is " +
|
||||
"invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
public void decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) throws Exception {
|
||||
private FLClientStatus decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList,
|
||||
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||
decryptSecretsList.clear();
|
||||
int size = returnShareList.size();
|
||||
if (size <= 0) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <returnShareList> is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (cuvKeys.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <cuvKeys> is null"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
for (int i = 0; i < size; i++) {
|
||||
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
|
||||
EncryptShare encryptShare = returnShareList.get(i);
|
||||
String vFlID = encryptShare.getFlID();
|
||||
byte[] share = encryptShare.getShare().getArray();
|
||||
byte[] iVecIn = new byte[IVEC_LEN];
|
||||
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), iVecIn, "CBC");
|
||||
if (!cuvKeys.containsKey(vFlID)) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the key <vFlID> is not in map <cuvKeys> "));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), "CBC");
|
||||
byte[] decryptShare = aesEncrypt.decrypt(cuvKeys.get(vFlID), share);
|
||||
if (decryptShare == null || decryptShare.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[decryptSecretShares] the return byte[] is null, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
if (decryptShare.length < 4) {
|
||||
LOGGER.severe(Common.addTag("[decryptSecretShares] the returned decryptShare is not valid: length is " +
|
||||
"not right, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
int sSize = (int) decryptShare[0];
|
||||
int bSize = (int) decryptShare[1];
|
||||
int sIndexLen = (int) decryptShare[2];
|
||||
int bIndexLen = (int) decryptShare[3];
|
||||
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen));
|
||||
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen, 4 + sIndexLen + bIndexLen));
|
||||
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen, 4 + sIndexLen + bIndexLen + sSize);
|
||||
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize, 4 + sIndexLen + bIndexLen + sSize + bSize);
|
||||
if (decryptShare.length < (4 + sIndexLen + bIndexLen + sSize + bSize)) {
|
||||
LOGGER.severe(Common.addTag("[decryptSecretShares] the returned decryptShare is not valid: length is " +
|
||||
"not right, please check!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen,
|
||||
4 + sIndexLen + bIndexLen + sSize);
|
||||
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize,
|
||||
4 + sIndexLen + bIndexLen + sSize + bSize);
|
||||
NewArray<byte[]> sSkVu = new NewArray<>();
|
||||
sSkVu.setSize(sSize);
|
||||
sSkVu.setArray(sSkUv);
|
||||
NewArray bVu = new NewArray();
|
||||
bVu.setSize(bSize);
|
||||
bVu.setArray(bUv);
|
||||
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen));
|
||||
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen,
|
||||
4 + sIndexLen + bIndexLen));
|
||||
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
|
||||
decryptShareSecrets.setFlID(vFlID);
|
||||
decryptShareSecrets.setSSkVu(sSkVu);
|
||||
decryptShareSecrets.setBVu(bVu);
|
||||
|
@ -165,5 +229,6 @@ public class ClientListReq {
|
|||
decryptShareSecrets.setIndexB(bIndex);
|
||||
decryptSecretsList.add(decryptShareSecrets);
|
||||
}
|
||||
return FLClientStatus.SUCCESS;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,6 +16,10 @@
|
|||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import org.bouncycastle.crypto.digests.SHA256Digest;
|
||||
import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator;
|
||||
import org.bouncycastle.crypto.params.KeyParameter;
|
||||
|
@ -25,39 +28,80 @@ import org.bouncycastle.math.ec.rfc7748.X25519;
|
|||
import java.security.SecureRandom;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Generate public-private key pairs and DH Keys.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class KEYAgreement {
|
||||
private static final Logger LOGGER = Logger.getLogger(KEYAgreement.class.toString());
|
||||
private static final int PBKDF2_ITERATIONS = 10000;
|
||||
private static final int SALT_SIZE = 32;
|
||||
private static final int HASH_BIT_SIZE = 256;
|
||||
private static final int KEY_LEN = X25519.SCALAR_SIZE;
|
||||
|
||||
private SecureRandom random = new SecureRandom();
|
||||
private SecureRandom random = Common.getSecureRandom();
|
||||
|
||||
/**
|
||||
* Generate private Key.
|
||||
*
|
||||
* @return the private Key.
|
||||
*/
|
||||
public byte[] generatePrivateKey() {
|
||||
byte[] privateKey = new byte[KEY_LEN];
|
||||
X25519.generatePrivateKey(random, privateKey);
|
||||
return privateKey;
|
||||
}
|
||||
|
||||
public byte[] generatePublicKey(byte[] privatekey) {
|
||||
/**
|
||||
* Use private Key to generate public Key.
|
||||
*
|
||||
* @param privateKey the private Key.
|
||||
* @return the public Key.
|
||||
*/
|
||||
public byte[] generatePublicKey(byte[] privateKey) {
|
||||
if (privateKey == null || privateKey.length == 0) {
|
||||
LOGGER.severe(Common.addTag("privateKey is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
byte[] publicKey = new byte[KEY_LEN];
|
||||
X25519.generatePublicKey(privatekey, 0, publicKey, 0);
|
||||
X25519.generatePublicKey(privateKey, 0, publicKey, 0);
|
||||
return publicKey;
|
||||
}
|
||||
|
||||
public byte[] keyAgreement(byte[] privatekey, byte[] publicKey) {
|
||||
/**
|
||||
* Use private Key and public Key to generate DH Key.
|
||||
*
|
||||
* @param privateKey the private Key.
|
||||
* @param publicKey the public Key.
|
||||
* @return the DH Key.
|
||||
*/
|
||||
public byte[] keyAgreement(byte[] privateKey, byte[] publicKey) {
|
||||
if (privateKey == null || privateKey.length == 0) {
|
||||
LOGGER.severe(Common.addTag("privateKey is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
if (publicKey == null || publicKey.length == 0) {
|
||||
LOGGER.severe(Common.addTag("publicKey is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
byte[] secret = new byte[KEY_LEN];
|
||||
X25519.calculateAgreement(privatekey, 0, publicKey, 0, secret, 0);
|
||||
X25519.calculateAgreement(privateKey, 0, publicKey, 0, secret, 0);
|
||||
return secret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Encrypt DH Key.
|
||||
*
|
||||
* @param password the DH Key.
|
||||
* @param salt the salt value.
|
||||
* @return encrypted DH Key.
|
||||
*/
|
||||
public byte[] getEncryptedPassword(byte[] password, byte[] salt) {
|
||||
|
||||
byte[] saltB = new byte[SALT_SIZE];
|
||||
if (password == null || password.length == 0) {
|
||||
LOGGER.severe(Common.addTag("password is null"));
|
||||
return new byte[0];
|
||||
}
|
||||
PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest());
|
||||
gen.init(password, saltB, PBKDF2_ITERATIONS);
|
||||
byte[] dk = ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey();
|
||||
return dk;
|
||||
gen.init(password, salt, PBKDF2_ITERATIONS);
|
||||
return ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import java.security.SecureRandom;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* Define the basic method for generating pairwise mask and individual mask.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class Masking {
|
||||
private static final Logger LOGGER = Logger.getLogger(Masking.class.toString());
|
||||
|
||||
/**
|
||||
* Random generate RNG algorithm name.
|
||||
*/
|
||||
private static final String RNG_ALGORITHM = "SHA1PRNG";
|
||||
|
||||
/**
|
||||
* Generate individual mask.
|
||||
*
|
||||
* @param secret used to store individual mask.
|
||||
* @return the int value, 0 indicates success, -1 indicates failed .
|
||||
*/
|
||||
public int getRandomBytes(byte[] secret) {
|
||||
if (secret == null || secret.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[Masking] the input argument <secret> is null, please check!"));
|
||||
return -1;
|
||||
}
|
||||
SecureRandom secureRandom = Common.getSecureRandom();
|
||||
secureRandom.nextBytes(secret);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate pairwise mask.
|
||||
*
|
||||
* @param noise used to store pairwise mask.
|
||||
* @param length the length of individual mask.
|
||||
* @param seed the seed for generate pairwise mask.
|
||||
* @param iVec the IV value.
|
||||
* @return the int value, 0 indicates success, -1 indicates failed .
|
||||
*/
|
||||
public int getMasking(List<Float> noise, int length, byte[] seed, byte[] iVec) {
|
||||
if (length <= 0) {
|
||||
LOGGER.severe(Common.addTag("[Masking] the input argument <length> is not valid: <= 0, please check!"));
|
||||
return -1;
|
||||
}
|
||||
int intV = Integer.SIZE / 8;
|
||||
int size = length * intV;
|
||||
byte[] data = new byte[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
data[i] = 0;
|
||||
}
|
||||
AESEncrypt aesEncrypt = new AESEncrypt(seed, "CTR");
|
||||
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data, iVec);
|
||||
if (encryptCtr == null || encryptCtr.length == 0) {
|
||||
LOGGER.severe(Common.addTag("[Masking] the return byte[] is null, please check!"));
|
||||
return -1;
|
||||
}
|
||||
for (int i = 0; i < length; i++) {
|
||||
int[] sub = new int[intV];
|
||||
for (int j = 0; j < 4; j++) {
|
||||
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
|
||||
}
|
||||
int subI = byte2int(sub, 4);
|
||||
if (subI == -1) {
|
||||
LOGGER.severe(Common.addTag("[Masking] the the returned <subI> is not valid: -1, please check!"));
|
||||
return -1;
|
||||
}
|
||||
Float fNoise = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
|
||||
noise.add(fNoise);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
private static int byte2int(int[] data, int number) {
|
||||
if (data.length < 4) {
|
||||
LOGGER.severe(Common.addTag("[Masking] the input argument <data> is not valid: length < 4, please check!"));
|
||||
return -1;
|
||||
}
|
||||
switch (number) {
|
||||
case 1:
|
||||
return (int) data[0];
|
||||
case 2:
|
||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
|
||||
case 3:
|
||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
|
||||
case 4:
|
||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
|
||||
| (data[3] << 24 & 0xff000000);
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,82 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
public class Random {
|
||||
/**
|
||||
* random generate RNG algorithm name
|
||||
*/
|
||||
private static final Logger LOGGER = Logger.getLogger(Random.class.toString());
|
||||
private static final String RNG_ALGORITHM = "SHA1PRNG";
|
||||
|
||||
private static final int RANDOM_LEN = 128 / 8;
|
||||
|
||||
public void getRandomBytes(byte[] secret) {
|
||||
try {
|
||||
SecureRandom secureRandom = SecureRandom.getInstance("SHA1PRNG");
|
||||
secureRandom.nextBytes(secret);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
public void randomAESCTR(List<Float> noise, int length, byte[] seed) throws Exception {
|
||||
int intV = Integer.SIZE / 8;
|
||||
int size = length * intV;
|
||||
byte[] data = new byte[size];
|
||||
for (int i = 0; i < size; i++) {
|
||||
data[i] = 0;
|
||||
}
|
||||
byte[] ivec = new byte[RANDOM_LEN];
|
||||
AESEncrypt aesEncrypt = new AESEncrypt(seed, ivec, "CTR");
|
||||
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data);
|
||||
for (int i = 0; i < length; i++) {
|
||||
int[] sub = new int[intV];
|
||||
for (int j = 0; j < 4; j++) {
|
||||
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
|
||||
}
|
||||
int subI = byte2int(sub, 4);
|
||||
|
||||
Float f = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
|
||||
noise.add(f);
|
||||
}
|
||||
}
|
||||
|
||||
public static int byte2int(int[] data, int n) {
|
||||
switch (n) {
|
||||
case 1:
|
||||
return (int) data[0];
|
||||
case 2:
|
||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
|
||||
case 3:
|
||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
|
||||
case 4:
|
||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
|
||||
| (data[3] << 24 & 0xff000000);
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,23 +16,33 @@
|
|||
|
||||
package com.mindspore.flclient.cipher;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
|
||||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
import com.mindspore.flclient.FLClientStatus;
|
||||
import com.mindspore.flclient.FLCommunication;
|
||||
import com.mindspore.flclient.FLParameter;
|
||||
import com.mindspore.flclient.LocalFLParameter;
|
||||
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
||||
import mindspore.schema.ClientShare;
|
||||
import mindspore.schema.ResponseCode;
|
||||
|
||||
import mindspore.schema.ClientShare;
|
||||
import mindspore.schema.ReconstructSecret;
|
||||
import mindspore.schema.ResponseCode;
|
||||
import mindspore.schema.SendReconstructSecret;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.time.LocalDateTime;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||
|
||||
/**
|
||||
* reconstruct secret request
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class ReconstructSecretReq {
|
||||
private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString());
|
||||
private FLCommunication flCommunication;
|
||||
|
@ -41,36 +51,44 @@ public class ReconstructSecretReq {
|
|||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||
private int retCode;
|
||||
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
public void setNextRequestTime(String nextRequestTime) {
|
||||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
|
||||
/**
|
||||
* reconstruct secret request
|
||||
*/
|
||||
public ReconstructSecretReq() {
|
||||
flCommunication = FLCommunication.getInstance();
|
||||
}
|
||||
|
||||
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList, List<String> u3ClientList, int iteration) {
|
||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
||||
/**
|
||||
* send secret shards to server
|
||||
*
|
||||
* @param decryptShareSecretsList secret shards list
|
||||
* @param u3ClientList u3 client list
|
||||
* @param iteration iter number
|
||||
* @return request result
|
||||
*/
|
||||
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList,
|
||||
List<String> u3ClientList, int iteration) {
|
||||
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
|
||||
FlatBufferBuilder builder = new FlatBufferBuilder();
|
||||
int desFlId = builder.createString(localFLParameter.getFlID());
|
||||
String dateTime = LocalDateTime.now().toString();
|
||||
Date date = new Date();
|
||||
long timestamp = date.getTime();
|
||||
String dateTime = String.valueOf(timestamp);
|
||||
int time = builder.createString(dateTime);
|
||||
int shareSecretsSize = decryptShareSecretsList.size();
|
||||
if (shareSecretsSize <= 0) {
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please waite."));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please " +
|
||||
"waite."));
|
||||
return FLClientStatus.FAILED;
|
||||
} else {
|
||||
int[] decryptShareList = new int[shareSecretsSize];
|
||||
for (int i = 0; i < shareSecretsSize; i++) {
|
||||
DecryptShareSecrets decryptShareSecrets = decryptShareSecretsList.get(i);
|
||||
if (decryptShareSecrets.getFlID() == null) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] get remote flID failed!"));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
|
||||
String srcFlId = decryptShareSecrets.getFlID();
|
||||
byte[] share;
|
||||
int index;
|
||||
|
@ -86,31 +104,33 @@ public class ReconstructSecretReq {
|
|||
int clientShare = ClientShare.createClientShare(builder, fbsSrcFlId, fbsShare, index);
|
||||
decryptShareList[i] = clientShare;
|
||||
}
|
||||
int reconstructShareSecrets = mindspore.schema.SendReconstructSecret.createReconstructSecretSharesVector(builder, decryptShareList);
|
||||
int reconstructSecretRoot = mindspore.schema.SendReconstructSecret.createSendReconstructSecret(builder, desFlId, reconstructShareSecrets, iteration, time);
|
||||
int reconstructShareSecrets = SendReconstructSecret.createReconstructSecretSharesVector(builder,
|
||||
decryptShareList);
|
||||
int reconstructSecretRoot = SendReconstructSecret.createSendReconstructSecret(builder, desFlId,
|
||||
reconstructShareSecrets, iteration, time);
|
||||
builder.finish(reconstructSecretRoot);
|
||||
byte[] msg = builder.sizedByteArray();
|
||||
try {
|
||||
byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg);
|
||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
||||
LOGGER.info(Common.addTag("[sendReconstructSecret] The cluster is in safemode, need wait some time and request again"));
|
||||
if (!Common.isSeverReady(responseData)) {
|
||||
LOGGER.info(Common.addTag("[sendReconstructSecret] the server is not ready now, need wait some " +
|
||||
"time and request again"));
|
||||
Common.sleep(SLEEP_TIME);
|
||||
nextRequestTime = "";
|
||||
return FLClientStatus.RESTART;
|
||||
}
|
||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||
mindspore.schema.ReconstructSecret reconstructSecretRsp = mindspore.schema.ReconstructSecret.getRootAsReconstructSecret(buffer);
|
||||
FLClientStatus status = judgeSendReconstructSecrets(reconstructSecretRsp);
|
||||
return status;
|
||||
} catch (Exception e) {
|
||||
ReconstructSecret reconstructSecretRsp = ReconstructSecret.getRootAsReconstructSecret(buffer);
|
||||
return judgeSendReconstructSecrets(reconstructSecretRsp);
|
||||
} catch (IOException ex) {
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] un solved error code in reconstruct"));
|
||||
e.printStackTrace();
|
||||
ex.printStackTrace();
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) {
|
||||
private FLClientStatus judgeSendReconstructSecrets(ReconstructSecret bufData) {
|
||||
retCode = bufData.retcode();
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||
|
@ -122,7 +142,8 @@ public class ReconstructSecretReq {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success"));
|
||||
return FLClientStatus.SUCCESS;
|
||||
case (ResponseCode.OutOfTime):
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request startFLJob again"));
|
||||
LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request " +
|
||||
"startFLJob again"));
|
||||
setNextRequestTime(bufData.nextReqTime());
|
||||
return FLClientStatus.RESTART;
|
||||
case (ResponseCode.RequestError):
|
||||
|
@ -130,8 +151,36 @@ public class ReconstructSecretReq {
|
|||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets"));
|
||||
return FLClientStatus.FAILED;
|
||||
default:
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is invalid: " + retCode));
|
||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is " +
|
||||
"invalid: " + retCode));
|
||||
return FLClientStatus.FAILED;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* get next request time
|
||||
*
|
||||
* @return next request time
|
||||
*/
|
||||
public String getNextRequestTime() {
|
||||
return nextRequestTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* set next request time
|
||||
*
|
||||
* @param nextRequestTime next request time
|
||||
*/
|
||||
public void setNextRequestTime(String nextRequestTime) {
|
||||
this.nextRequestTime = nextRequestTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* get retCode
|
||||
*
|
||||
* @return retCode
|
||||
*/
|
||||
public int getRetCode() {
|
||||
return retCode;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -22,115 +22,158 @@ import java.math.BigInteger;
|
|||
import java.util.Random;
|
||||
import java.util.logging.Logger;
|
||||
|
||||
|
||||
/**
|
||||
* Define functions that for splitting secret and combining secret shards.
|
||||
*
|
||||
* @since 2021-06-30
|
||||
*/
|
||||
public class ShareSecrets {
|
||||
private static final Logger LOGGER = Logger.getLogger(ShareSecrets.class.toString());
|
||||
|
||||
public final class SecretShare {
|
||||
public SecretShare(final int num, final BigInteger share) {
|
||||
this.num = num;
|
||||
this.share = share;
|
||||
}
|
||||
private BigInteger prime;
|
||||
private final int minNum;
|
||||
private final int totalNum;
|
||||
private final Random random;
|
||||
|
||||
public int getNum() {
|
||||
return num;
|
||||
/**
|
||||
* Defines the constructor of the class ShareSecrets.
|
||||
*
|
||||
* @param minNum minimum number of fragments required to reconstruct a secret.
|
||||
* @param totalNum total clients number.
|
||||
*/
|
||||
public ShareSecrets(final int minNum, final int totalNum) {
|
||||
if (minNum <= 0) {
|
||||
LOGGER.severe(Common.addTag("the argument <k> is not valid: <= 0, it should be > 0"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
public BigInteger getShare() {
|
||||
return share;
|
||||
if (totalNum <= 0) {
|
||||
LOGGER.severe(Common.addTag("the argument <n> is not valid: <= 0, it should be > 0"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "SecretShare [num=" + num + ", share=" + share + "]";
|
||||
if (minNum > totalNum) {
|
||||
LOGGER.severe(Common.addTag("the argument <k, n> is not valid: k > n, it should k <= n"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
private final int num;
|
||||
private final BigInteger share;
|
||||
this.minNum = minNum;
|
||||
this.totalNum = totalNum;
|
||||
random = Common.getSecureRandom();
|
||||
}
|
||||
|
||||
public ShareSecrets(final int k, final int n) {
|
||||
this.k = k;
|
||||
this.n = n;
|
||||
|
||||
random = new Random();
|
||||
}
|
||||
|
||||
public SecretShare[] split(final byte[] bytes, byte[] primeByte) {
|
||||
/**
|
||||
* Splits a secret into a specified number of secret fragments.
|
||||
*
|
||||
* @param bytes the secret need to be split.
|
||||
* @param primeByte teh big prime number used to combine secret fragments.
|
||||
* @return the secret fragments.
|
||||
*/
|
||||
public SecretShares[] split(final byte[] bytes, byte[] primeByte) {
|
||||
if (bytes == null || bytes.length == 0) {
|
||||
LOGGER.severe(Common.addTag("the input argument <bytes> is null"));
|
||||
return new SecretShares[0];
|
||||
}
|
||||
if (primeByte == null || primeByte.length == 0) {
|
||||
LOGGER.severe(Common.addTag("the input argument <primeByte> is null"));
|
||||
return new SecretShares[0];
|
||||
}
|
||||
BigInteger secret = BaseUtil.byteArray2BigInteger(bytes);
|
||||
final int modLength = secret.bitLength() + 1;
|
||||
prime = BaseUtil.byteArray2BigInteger(primeByte);
|
||||
final BigInteger[] coeff = new BigInteger[k - 1];
|
||||
final BigInteger[] coefficient = new BigInteger[minNum - 1];
|
||||
|
||||
LOGGER.info(Common.addTag("Prime Number: " + prime));
|
||||
|
||||
for (int i = 0; i < k - 1; i++) {
|
||||
coeff[i] = randomZp(prime);
|
||||
LOGGER.info(Common.addTag("a" + (i + 1) + ": " + coeff[i]));
|
||||
for (int i = 0; i < minNum - 1; i++) {
|
||||
coefficient[i] = randomZp(prime);
|
||||
}
|
||||
|
||||
final SecretShare[] shares = new SecretShare[n];
|
||||
for (int i = 1; i <= n; i++) {
|
||||
BigInteger accum = secret;
|
||||
final SecretShares[] shares = new SecretShares[totalNum];
|
||||
for (int i = 1; i <= totalNum; i++) {
|
||||
BigInteger accumulate = secret;
|
||||
|
||||
for (int j = 1; j < k; j++) {
|
||||
final BigInteger t1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime);
|
||||
final BigInteger t2 = coeff[j - 1].multiply(t1).mod(prime);
|
||||
for (int j = 1; j < minNum; j++) {
|
||||
final BigInteger b1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime);
|
||||
final BigInteger b2 = coefficient[j - 1].multiply(b1).mod(prime);
|
||||
|
||||
accum = accum.add(t2).mod(prime);
|
||||
accumulate = accumulate.add(b2).mod(prime);
|
||||
}
|
||||
shares[i - 1] = new SecretShare(i, accum);
|
||||
LOGGER.info(Common.addTag("Share " + shares[i - 1]));
|
||||
shares[i - 1] = new SecretShares(i, accumulate);
|
||||
}
|
||||
return shares;
|
||||
}
|
||||
|
||||
public BigInteger getPrime() {
|
||||
return prime;
|
||||
}
|
||||
|
||||
public BigInteger combine(final SecretShare[] shares, final byte[] primeByte) {
|
||||
/**
|
||||
* Combine secret fragments.
|
||||
*
|
||||
* @param shares the secret fragments.
|
||||
* @param primeByte teh big prime number used to combine secret fragments.
|
||||
* @return the secrets combined by secret fragments.
|
||||
*/
|
||||
public BigInteger combine(final SecretShares[] shares, final byte[] primeByte) {
|
||||
if (shares == null || shares.length == 0) {
|
||||
LOGGER.severe(Common.addTag("the input argument <shares> is null"));
|
||||
return BigInteger.ZERO;
|
||||
}
|
||||
if (primeByte == null || primeByte.length == 0) {
|
||||
LOGGER.severe(Common.addTag("the input argument <primeByte> is null"));
|
||||
return BigInteger.ZERO;
|
||||
}
|
||||
BigInteger primeNum = BaseUtil.byteArray2BigInteger(primeByte);
|
||||
BigInteger accum = BigInteger.ZERO;
|
||||
for (int j = 0; j < k; j++) {
|
||||
BigInteger accumulate = BigInteger.ZERO;
|
||||
for (int j = 0; j < minNum; j++) {
|
||||
BigInteger num = BigInteger.ONE;
|
||||
BigInteger den = BigInteger.ONE;
|
||||
|
||||
BigInteger tmp;
|
||||
|
||||
for (int m = 0; m < k; m++) {
|
||||
for (int m = 0; m < minNum; m++) {
|
||||
if (j != m) {
|
||||
num = num.multiply(BigInteger.valueOf(shares[m].getNum())).mod(primeNum);
|
||||
tmp = BigInteger.valueOf(shares[j].getNum()).multiply(BigInteger.valueOf(-1));
|
||||
tmp = BigInteger.valueOf(shares[m].getNum()).add(tmp).mod(primeNum);
|
||||
num = num.multiply(BigInteger.valueOf(shares[m].getNumber())).mod(primeNum);
|
||||
tmp = BigInteger.valueOf(shares[j].getNumber()).multiply(BigInteger.valueOf(-1));
|
||||
tmp = BigInteger.valueOf(shares[m].getNumber()).add(tmp).mod(primeNum);
|
||||
den = den.multiply(tmp).mod(primeNum);
|
||||
}
|
||||
}
|
||||
final BigInteger value = shares[j].getShare();
|
||||
final BigInteger value = shares[j].getShares();
|
||||
|
||||
tmp = den.modInverse(primeNum);
|
||||
tmp = tmp.multiply(num).mod(primeNum);
|
||||
tmp = tmp.multiply(value).mod(primeNum);
|
||||
accum = accum.add(tmp).mod(primeNum);
|
||||
LOGGER.info(Common.addTag("value: " + value + ", tmp: " + tmp + ", accum: " + accum));
|
||||
accumulate = accumulate.add(tmp).mod(primeNum);
|
||||
}
|
||||
LOGGER.info(Common.addTag("The secret is: " + accum));
|
||||
return accum;
|
||||
return accumulate;
|
||||
}
|
||||
|
||||
private BigInteger randomZp(final BigInteger p) {
|
||||
private BigInteger randomZp(final BigInteger num) {
|
||||
while (true) {
|
||||
final BigInteger r = new BigInteger(p.bitLength(), random);
|
||||
if (r.compareTo(BigInteger.ZERO) > 0 && r.compareTo(p) < 0) {
|
||||
return r;
|
||||
final BigInteger rand = new BigInteger(num.bitLength(), random);
|
||||
if (rand.compareTo(BigInteger.ZERO) > 0 && rand.compareTo(num) < 0) {
|
||||
return rand;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private BigInteger prime;
|
||||
private final int k;
|
||||
private final int n;
|
||||
private final Random random;
|
||||
private final int SECRET_MAX_LEN = 32;
|
||||
/**
|
||||
* Define the structure for store secret fragments.
|
||||
*/
|
||||
public final class SecretShares {
|
||||
private final int number;
|
||||
private final BigInteger share;
|
||||
|
||||
public SecretShares(final int number, final BigInteger share) {
|
||||
this.number = number;
|
||||
this.share = share;
|
||||
}
|
||||
|
||||
public int getNumber() {
|
||||
return number;
|
||||
}
|
||||
|
||||
public BigInteger getShares() {
|
||||
return share;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "SecretShares [number=" + number + ", share=" + share + "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,33 +16,114 @@
|
|||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* public key class of secure aggregation
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class ClientPublicKey {
|
||||
private static final Logger LOGGER = Logger.getLogger(ClientPublicKey.class.toString());
|
||||
private String flID;
|
||||
private NewArray<byte[]> cPK;
|
||||
private NewArray<byte[]> sPk;
|
||||
private NewArray<byte[]> pwIv;
|
||||
private NewArray<byte[]> pwSalt;
|
||||
|
||||
/**
|
||||
* get client's flID
|
||||
*
|
||||
* @return flID of this client
|
||||
*/
|
||||
public String getFlID() {
|
||||
if (flID == null || flID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of <flID> is null, please set it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flID;
|
||||
}
|
||||
|
||||
/**
|
||||
* set client's flID
|
||||
*
|
||||
* @param flID hash value used for identify client
|
||||
*/
|
||||
public void setFlID(String flID) {
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
/**
|
||||
* get CPK of secure aggregation
|
||||
*
|
||||
* @return CPK of secure aggregation
|
||||
*/
|
||||
public NewArray<byte[]> getCPK() {
|
||||
return cPK;
|
||||
}
|
||||
|
||||
/**
|
||||
* set CPK of secure aggregation
|
||||
*
|
||||
* @param cPK public key used for encryption
|
||||
*/
|
||||
public void setCPK(NewArray<byte[]> cPK) {
|
||||
this.cPK = cPK;
|
||||
}
|
||||
|
||||
/**
|
||||
* get SPK of secure aggregation
|
||||
*
|
||||
* @return SPK of secure aggregation
|
||||
*/
|
||||
public NewArray<byte[]> getSPK() {
|
||||
return sPk;
|
||||
}
|
||||
|
||||
/**
|
||||
* set SPK of secure aggregation
|
||||
*
|
||||
* @param sPk public key used for encryption
|
||||
*/
|
||||
public void setSPK(NewArray<byte[]> sPk) {
|
||||
this.sPk = sPk;
|
||||
}
|
||||
|
||||
/**
|
||||
* get the IV value used for pairwise encrypt
|
||||
*
|
||||
* @return the IV value used for pairwise encrypt
|
||||
*/
|
||||
public NewArray<byte[]> getPwIv() {
|
||||
return pwIv;
|
||||
}
|
||||
|
||||
/**
|
||||
* set the IV value used for pairwise encrypt
|
||||
*
|
||||
* @param pwIv IV value used for pairwise encrypt
|
||||
*/
|
||||
public void setPwIv(NewArray<byte[]> pwIv) {
|
||||
this.pwIv = pwIv;
|
||||
}
|
||||
|
||||
/**
|
||||
* get salt value for secure aggregation
|
||||
*
|
||||
* @return salt value for secure aggregation
|
||||
*/
|
||||
public NewArray<byte[]> getPwSalt() {
|
||||
return pwSalt;
|
||||
}
|
||||
|
||||
/**
|
||||
* set salt value for secure aggregation
|
||||
*
|
||||
* @param pwSalt salt value for secure aggregation
|
||||
*/
|
||||
public void setPwSalt(NewArray<byte[]> pwSalt) {
|
||||
this.pwSalt = pwSalt;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,49 +16,114 @@
|
|||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* class used for set and get decryption shards
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class DecryptShareSecrets {
|
||||
private static final Logger LOGGER = Logger.getLogger(DecryptShareSecrets.class.toString());
|
||||
private String flID;
|
||||
private NewArray<byte[]> sSkVu;
|
||||
private NewArray<byte[]> bVu;
|
||||
private int sIndex;
|
||||
private int indexB;
|
||||
|
||||
/**
|
||||
* get flID of client
|
||||
*
|
||||
* @return flID of this client
|
||||
*/
|
||||
public String getFlID() {
|
||||
if (flID == null || flID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
|
||||
"use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flID;
|
||||
}
|
||||
|
||||
/**
|
||||
* set flID for this client
|
||||
*
|
||||
* @param flID hash value used for identify client
|
||||
*/
|
||||
public void setFlID(String flID) {
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
/**
|
||||
* get secret key shards
|
||||
*
|
||||
* @return secret key shards
|
||||
*/
|
||||
public NewArray<byte[]> getSSkVu() {
|
||||
return sSkVu;
|
||||
}
|
||||
|
||||
/**
|
||||
* set secret key shards
|
||||
*
|
||||
* @param sSkVu secret key shards
|
||||
*/
|
||||
public void setSSkVu(NewArray<byte[]> sSkVu) {
|
||||
this.sSkVu = sSkVu;
|
||||
}
|
||||
|
||||
/**
|
||||
* get bu shards
|
||||
*
|
||||
* @return bu shards
|
||||
*/
|
||||
public NewArray<byte[]> getBVu() {
|
||||
return bVu;
|
||||
}
|
||||
|
||||
/**
|
||||
* set bu shards
|
||||
*
|
||||
* @param bVu bu shards used for secure aggregation
|
||||
*/
|
||||
public void setBVu(NewArray<byte[]> bVu) {
|
||||
this.bVu = bVu;
|
||||
}
|
||||
|
||||
/**
|
||||
* get index of secret shards
|
||||
*
|
||||
* @return index of secret shards
|
||||
*/
|
||||
public int getSIndex() {
|
||||
return sIndex;
|
||||
}
|
||||
|
||||
/**
|
||||
* set index of secret shards
|
||||
*
|
||||
* @param sIndex index of secret shards
|
||||
*/
|
||||
public void setSIndex(int sIndex) {
|
||||
this.sIndex = sIndex;
|
||||
}
|
||||
|
||||
/**
|
||||
* get index of bu shards
|
||||
*
|
||||
* @return index of bu shards
|
||||
*/
|
||||
public int getIndexB() {
|
||||
return indexB;
|
||||
}
|
||||
|
||||
/**
|
||||
* set index of bu shards
|
||||
*
|
||||
* @param indexB index of bu shards
|
||||
*/
|
||||
public void setIndexB(int indexB) {
|
||||
this.indexB = indexB;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,22 +16,57 @@
|
|||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* class used for encrypt shares of secret
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class EncryptShare {
|
||||
private static final Logger LOGGER = Logger.getLogger(DecryptShareSecrets.class.toString());
|
||||
private String flID;
|
||||
private NewArray<byte[]> share;
|
||||
|
||||
/**
|
||||
* get client's flID
|
||||
*
|
||||
* @return flID of this client
|
||||
*/
|
||||
public String getFlID() {
|
||||
if (flID == null || flID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
|
||||
"use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flID;
|
||||
}
|
||||
|
||||
/**
|
||||
* set client's flID
|
||||
*
|
||||
* @param flID hash value used for identify client
|
||||
*/
|
||||
public void setFlID(String flID) {
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
/**
|
||||
* get secret share
|
||||
*
|
||||
* @return secret share
|
||||
*/
|
||||
public NewArray<byte[]> getShare() {
|
||||
return share;
|
||||
}
|
||||
|
||||
/**
|
||||
* set secret share
|
||||
*
|
||||
* @param share secret share
|
||||
*/
|
||||
public void setShare(NewArray<byte[]> share) {
|
||||
this.share = share;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,24 +16,50 @@
|
|||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
/**
|
||||
* class used define new array type
|
||||
*
|
||||
* @param <T> an array
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class NewArray<T> {
|
||||
private int size;
|
||||
private T array;
|
||||
|
||||
/**
|
||||
* get array size
|
||||
*
|
||||
* @return array size
|
||||
*/
|
||||
public int getSize() {
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
|
||||
* set array size
|
||||
*
|
||||
* @param size array size
|
||||
*/
|
||||
public void setSize(int size) {
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
/**
|
||||
* get array
|
||||
*
|
||||
* @return an array
|
||||
*/
|
||||
public T getArray() {
|
||||
return array;
|
||||
}
|
||||
|
||||
/**
|
||||
* set array
|
||||
*
|
||||
* @param array input
|
||||
*/
|
||||
public void setArray(T array) {
|
||||
this.array = array;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,31 +16,75 @@
|
|||
|
||||
package com.mindspore.flclient.cipher.struct;
|
||||
|
||||
import com.mindspore.flclient.Common;
|
||||
|
||||
import java.util.logging.Logger;
|
||||
|
||||
/**
|
||||
* share secret class
|
||||
*
|
||||
* @since 2021-8-27
|
||||
*/
|
||||
public class ShareSecret {
|
||||
private static final Logger LOGGER = Logger.getLogger(ShareSecret.class.toString());
|
||||
private String flID;
|
||||
private NewArray<byte[]> share;
|
||||
private int index;
|
||||
|
||||
/**
|
||||
* get client's flID
|
||||
*
|
||||
* @return flID of this client
|
||||
*/
|
||||
public String getFlID() {
|
||||
if (flID == null || flID.isEmpty()) {
|
||||
LOGGER.severe(Common.addTag("[ShareSecret] the parameter of <flID> is null, please set it before use"));
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return flID;
|
||||
}
|
||||
|
||||
/**
|
||||
* set flID for this client
|
||||
*
|
||||
* @param flID hash value used for identify client
|
||||
*/
|
||||
public void setFlID(String flID) {
|
||||
this.flID = flID;
|
||||
}
|
||||
|
||||
/**
|
||||
* get secret share
|
||||
*
|
||||
* @return secret share
|
||||
*/
|
||||
public NewArray<byte[]> getShare() {
|
||||
return share;
|
||||
}
|
||||
|
||||
/**
|
||||
* set secret share
|
||||
*
|
||||
* @param share secret shares
|
||||
*/
|
||||
public void setShare(NewArray<byte[]> share) {
|
||||
this.share = share;
|
||||
}
|
||||
|
||||
/**
|
||||
* get secret index
|
||||
*
|
||||
* @return secret index
|
||||
*/
|
||||
public int getIndex() {
|
||||
return index;
|
||||
}
|
||||
|
||||
/**
|
||||
* set secret index
|
||||
*
|
||||
* @param index secret index
|
||||
*/
|
||||
public void setIndex(int index) {
|
||||
this.index = index;
|
||||
}
|
||||
|
|
|
@ -31,6 +31,8 @@ table ClientPublicKeys {
|
|||
fl_id:string;
|
||||
c_pk:[ubyte];
|
||||
s_pk: [ubyte];
|
||||
pw_iv: [ubyte];
|
||||
pw_salt: [ubyte];
|
||||
}
|
||||
|
||||
table ClientShare {
|
||||
|
@ -45,6 +47,9 @@ table RequestExchangeKeys{
|
|||
s_pk:[ubyte];
|
||||
iteration:int;
|
||||
timestamp:string;
|
||||
ind_iv:[ubyte];
|
||||
pw_iv:[ubyte];
|
||||
pw_salt:[ubyte];
|
||||
}
|
||||
|
||||
table ResponseExchangeKeys{
|
||||
|
|
Loading…
Reference in New Issue