forked from mindspore-Ecosystem/mindspore
Fix security problems and code-check problems for federated's secure aggregation
fix security check problems for flclient
This commit is contained in:
parent
0abff9ad65
commit
7d9dd343f3
|
@ -39,7 +39,7 @@ if(NOT ENABLE_CPU OR WIN32)
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/key_agreement.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/key_agreement.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/random.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/masking.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/secret_sharing.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/secret_sharing.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_init.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_init.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_keys.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_keys.cc")
|
||||||
|
|
|
@ -22,26 +22,28 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
|
bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
|
||||||
bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_initial_client_cnt,
|
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
|
||||||
size_t cipher_exchange_secrets_cnt, size_t cipher_share_secrets_cnt,
|
|
||||||
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
|
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
|
||||||
size_t cipher_reconstruct_secrets_up_cnt) {
|
size_t cipher_reconstruct_secrets_up_cnt) {
|
||||||
MS_LOG(INFO) << "CipherInit::Init START";
|
MS_LOG(INFO) << "CipherInit::Init START";
|
||||||
int return_num = 0;
|
if (memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN) != 0) {
|
||||||
return_num = memcpy_s(publicparam_.p, SECRET_MAX_LEN, param.p, SECRET_MAX_LEN);
|
MS_LOG(ERROR) << "CipherInit::memory copy failed.";
|
||||||
if (return_num != 0) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
publicparam_.g = param.g;
|
publicparam_.g = param.g;
|
||||||
publicparam_.t = param.t;
|
publicparam_.t = param.t;
|
||||||
secrets_minnums_ = param.t;
|
secrets_minnums_ = param.t;
|
||||||
client_num_need_ = cipher_initial_client_cnt;
|
|
||||||
featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float);
|
featuremap_ = fl::server::ModelStore::GetInstance().model_size() / sizeof(float);
|
||||||
share_clients_num_need_ = cipher_share_secrets_cnt;
|
|
||||||
reconstruct_clients_num_need_ = cipher_reconstruct_secrets_down_cnt + 1;
|
exchange_key_threshold = cipher_exchange_keys_cnt;
|
||||||
get_model_num_need_ = cipher_get_clientlist_cnt;
|
get_key_threshold = cipher_get_keys_cnt;
|
||||||
|
share_secrets_threshold = cipher_share_secrets_cnt;
|
||||||
|
get_secrets_threshold = cipher_get_secrets_cnt;
|
||||||
|
client_list_threshold = cipher_get_clientlist_cnt;
|
||||||
|
reconstruct_secrets_threshold = cipher_reconstruct_secrets_up_cnt;
|
||||||
|
|
||||||
time_out_mutex_ = time_out_mutex;
|
time_out_mutex_ = time_out_mutex;
|
||||||
publicparam_.dp_eps = param.dp_eps;
|
publicparam_.dp_eps = param.dp_eps;
|
||||||
publicparam_.dp_delta = param.dp_delta;
|
publicparam_.dp_delta = param.dp_delta;
|
||||||
|
@ -62,10 +64,12 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
||||||
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
|
MS_LOG(ERROR) << "Cipher Param Update is invalid.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << " CipherInit client_num_need_ : " << client_num_need_;
|
MS_LOG(INFO) << " CipherInit exchange_key_threshold : " << exchange_key_threshold;
|
||||||
MS_LOG(INFO) << " CipherInit share_clients_num_need_ : " << share_clients_num_need_;
|
MS_LOG(INFO) << " CipherInit get_key_threshold : " << get_key_threshold;
|
||||||
MS_LOG(INFO) << " CipherInit reconstruct_clients_num_need_ : " << reconstruct_clients_num_need_;
|
MS_LOG(INFO) << " CipherInit share_secrets_threshold : " << share_secrets_threshold;
|
||||||
MS_LOG(INFO) << " CipherInit get_model_num_need_ : " << get_model_num_need_;
|
MS_LOG(INFO) << " CipherInit get_secrets_threshold : " << get_secrets_threshold;
|
||||||
|
MS_LOG(INFO) << " CipherInit client_list_threshold : " << client_list_threshold;
|
||||||
|
MS_LOG(INFO) << " CipherInit reconstruct_secrets_threshold : " << reconstruct_secrets_threshold;
|
||||||
MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_;
|
MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_;
|
||||||
if (!Check_Parames()) {
|
if (!Check_Parames()) {
|
||||||
MS_LOG(ERROR) << "Cipher parameters are illegal.";
|
MS_LOG(ERROR) << "Cipher parameters are illegal.";
|
||||||
|
@ -82,11 +86,10 @@ bool CipherInit::Check_Parames() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (share_clients_num_need_ < reconstruct_clients_num_need_) {
|
if (share_secrets_threshold < reconstruct_secrets_threshold) {
|
||||||
MS_LOG(ERROR)
|
MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger "
|
||||||
<< "reconstruct_clients_num_need (which is reconstruct_secrets_threshold + 1) should not be larger "
|
"than share_secrets_threshold, but got they are:"
|
||||||
"than share_clients_num_need (which is start_fl_job_threshold*share_secrets_ratio), but got they are:"
|
<< reconstruct_secrets_threshold << ", " << share_secrets_threshold;
|
||||||
<< reconstruct_clients_num_need_ << ", " << share_clients_num_need_;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,17 +29,6 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
|
|
||||||
template <typename T1>
|
|
||||||
bool CreateArray(std::vector<T1> *newData, const flatbuffers::Vector<T1> &fbs_arr) {
|
|
||||||
size_t size = newData->size();
|
|
||||||
size_t size_fbs_arr = fbs_arr.size();
|
|
||||||
if (size != size_fbs_arr) return false;
|
|
||||||
for (size_t i = 0; i < size; ++i) {
|
|
||||||
newData->at(i) = fbs_arr.Get(i);
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialization of secure aggregation.
|
// Initialization of secure aggregation.
|
||||||
class CipherInit {
|
class CipherInit {
|
||||||
public:
|
public:
|
||||||
|
@ -49,9 +38,10 @@ class CipherInit {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the parameters of the secure aggregation.
|
// Initialize the parameters of the secure aggregation.
|
||||||
bool Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_initial_client_cnt,
|
bool Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
|
||||||
size_t cipher_exchange_secrets_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_clientlist_cnt,
|
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
|
||||||
size_t cipher_reconstruct_secrets_down_cnt, size_t cipher_reconstruct_secrets_up_cnt);
|
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
|
||||||
|
size_t cipher_reconstruct_secrets_up_cnt);
|
||||||
|
|
||||||
// Check whether the parameters are valid.
|
// Check whether the parameters are valid.
|
||||||
bool Check_Parames();
|
bool Check_Parames();
|
||||||
|
@ -59,10 +49,12 @@ class CipherInit {
|
||||||
// Get public params. which is given to start fl job thread.
|
// Get public params. which is given to start fl job thread.
|
||||||
CipherPublicPara *GetPublicParams() { return &publicparam_; }
|
CipherPublicPara *GetPublicParams() { return &publicparam_; }
|
||||||
|
|
||||||
size_t share_clients_num_need_; // the minimum number of clients to share secrets.
|
size_t share_secrets_threshold; // the minimum number of clients to share secret fragments.
|
||||||
size_t reconstruct_clients_num_need_; // the minimum number of clients to reconstruct secret mask.
|
size_t get_secrets_threshold; // the minimum number of clients to get secret fragments.
|
||||||
size_t client_num_need_; // the minimum number of clients to update model.
|
size_t reconstruct_secrets_threshold; // the minimum number of clients to reconstruct secret mask.
|
||||||
size_t get_model_num_need_; // the minimum number of clients to get model.
|
size_t exchange_key_threshold; // the minimum number of clients to send public keys.
|
||||||
|
size_t get_key_threshold; // the minimum number of clients to get public keys.
|
||||||
|
size_t client_list_threshold; // the minimum number of clients to get update model client list.
|
||||||
|
|
||||||
size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask.
|
size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask.
|
||||||
size_t featuremap_; // the size of data to deal.
|
size_t featuremap_; // the size of data to deal.
|
||||||
|
|
|
@ -21,203 +21,170 @@ namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
|
bool CipherKeys::GetKeys(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::GetExchangeKeys *get_exchange_keys_req,
|
const schema::GetExchangeKeys *get_exchange_keys_req,
|
||||||
const std::shared_ptr<fl::server::FBBuilder> &get_exchange_keys_resp_builder) {
|
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
|
||||||
MS_LOG(INFO) << "CipherMgr::GetKeys START";
|
MS_LOG(INFO) << "CipherMgr::GetKeys START";
|
||||||
if (get_exchange_keys_req == nullptr || get_exchange_keys_resp_builder == nullptr) {
|
if (get_exchange_keys_req == nullptr) {
|
||||||
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
|
MS_LOG(ERROR) << "Request is nullptr";
|
||||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false);
|
BuildGetKeysRsp(fbb, schema::ResponseCode_SystemError, cur_iterator, next_req_time, false);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// get clientlist from memory server.
|
// get clientlist from memory server.
|
||||||
std::vector<std::string> clients;
|
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
|
||||||
|
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||||
|
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &clients);
|
size_t cur_exchange_clients_num = client_public_keys.size();
|
||||||
|
|
||||||
size_t cur_clients_num = clients.size();
|
|
||||||
std::string fl_id = get_exchange_keys_req->fl_id()->str();
|
std::string fl_id = get_exchange_keys_req->fl_id()->str();
|
||||||
|
|
||||||
if (find(clients.begin(), clients.end(), fl_id) == clients.end()) {
|
if (cur_exchange_clients_num < cipher_init_->exchange_key_threshold) {
|
||||||
MS_LOG(INFO) << "The fl_id is not in clients.";
|
MS_LOG(INFO) << "The server is not ready yet: cur_exchangekey_clients_num < exchange_key_threshold";
|
||||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
|
MS_LOG(INFO) << "cur_exchangekey_clients_num : " << cur_exchange_clients_num
|
||||||
|
<< ", exchange_key_threshold : " << cipher_init_->exchange_key_threshold;
|
||||||
|
BuildGetKeysRsp(fbb, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (cur_clients_num < cipher_init_->client_num_need_) {
|
|
||||||
MS_LOG(INFO) << "The server is not ready yet: cur_clients_num < client_num_need";
|
if (client_public_keys.find(fl_id) == client_public_keys.end()) {
|
||||||
MS_LOG(INFO) << "cur_clients_num : " << cur_clients_num << ", client_num_need : " << cipher_init_->client_num_need_;
|
MS_LOG(INFO) << "Get keys: the fl_id: " << fl_id << "is not in exchange keys clients.";
|
||||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SucNotReady, cur_iterator, next_req_time, false);
|
BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, cur_iterator, next_req_time, false);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ret = cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetKeysClientList, fl_id);
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(ERROR) << "update get keys clients failed";
|
||||||
|
BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, cur_iterator, next_req_time, false);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "GetKeys client list: ";
|
MS_LOG(INFO) << "GetKeys client list: ";
|
||||||
for (size_t i = 0; i < clients.size(); i++) {
|
BuildGetKeysRsp(fbb, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
|
||||||
MS_LOG(INFO) << "fl_id: " << clients[i];
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool flag =
|
|
||||||
BuildGetKeys(get_exchange_keys_resp_builder, schema::ResponseCode_SUCCEED, cur_iterator, next_req_time, true);
|
|
||||||
return flag;
|
|
||||||
} // namespace armour
|
|
||||||
|
|
||||||
bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
bool CipherKeys::ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::RequestExchangeKeys *exchange_keys_req,
|
const schema::RequestExchangeKeys *exchange_keys_req,
|
||||||
const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder) {
|
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
|
||||||
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
|
MS_LOG(INFO) << "CipherMgr::ExchangeKeys START";
|
||||||
// step 0: judge if the input param is legal.
|
// step 0: judge if the input param is legal.
|
||||||
if (exchange_keys_req == nullptr || exchange_keys_resp_builder == nullptr) {
|
if (exchange_keys_req == nullptr) {
|
||||||
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
|
std::string reason = "Request is nullptr";
|
||||||
std::string reason = "Request is nullptr or Response builder is nullptr.";
|
MS_LOG(ERROR) << reason;
|
||||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time,
|
BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason, next_req_time, cur_iterator);
|
||||||
cur_iterator);
|
return false;
|
||||||
|
}
|
||||||
|
std::string fl_id = exchange_keys_req->fl_id()->str();
|
||||||
|
mindspore::fl::PBMetadata device_metas =
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxDeviceMetas);
|
||||||
|
mindspore::fl::FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
|
||||||
|
MS_LOG(INFO) << "exchange key for fl id " << fl_id;
|
||||||
|
if (fl_id_to_meta.fl_id_to_meta().count(fl_id) == 0) {
|
||||||
|
std::string reason = "devices_meta for " + fl_id + " is not set. Please retry later.";
|
||||||
|
BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// step 1: get clientlist and client keys from memory server.
|
// step 1: get clientlist and client keys from memory server.
|
||||||
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
|
||||||
std::vector<std::string> client_list;
|
std::vector<std::string> client_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
|
||||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||||
|
|
||||||
// step2: process new item data. and update new item data to memory server.
|
// step2: process new item data. and update new item data to memory server.
|
||||||
size_t cur_clients_num = client_list.size();
|
size_t cur_clients_num = client_list.size();
|
||||||
size_t cur_clients_has_keys_num = record_public_keys.size();
|
size_t cur_clients_has_keys_num = client_public_keys.size();
|
||||||
if (cur_clients_num != cur_clients_has_keys_num) {
|
if (cur_clients_num != cur_clients_has_keys_num) {
|
||||||
std::string reason = "client num and keys num are not equal.";
|
std::string reason = "client num and keys num are not equal.";
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(WARNING) << reason;
|
||||||
MS_LOG(ERROR) << "cur_clients_num is " << cur_clients_num << ". cur_clients_has_keys_num is "
|
MS_LOG(WARNING) << "cur_clients_num is " << cur_clients_num << ". cur_clients_has_keys_num is "
|
||||||
<< cur_clients_has_keys_num;
|
<< cur_clients_has_keys_num;
|
||||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, reason, next_req_time,
|
|
||||||
cur_iterator);
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
MS_LOG(WARNING) << "exchange_key_threshold " << cipher_init_->exchange_key_threshold << ". cur_clients_num "
|
||||||
|
<< cur_clients_num << ". cur_clients_keys_num " << cur_clients_has_keys_num;
|
||||||
|
|
||||||
MS_LOG(INFO) << "client_num_need_ " << cipher_init_->client_num_need_ << ". cur_clients_num " << cur_clients_num;
|
if (client_public_keys.find(fl_id) != client_public_keys.end()) { // the client already exists, return false.
|
||||||
std::string fl_id = exchange_keys_req->fl_id()->str();
|
MS_LOG(ERROR) << "The server has received the request, please do not request again.";
|
||||||
if (cur_clients_num >= cipher_init_->client_num_need_) { // the client num is enough, return false.
|
BuildExchangeKeysRsp(fbb, schema::ResponseCode_SUCCEED,
|
||||||
MS_LOG(ERROR) << "The server has received enough requests and refuse this request.";
|
|
||||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime,
|
|
||||||
"The server has received enough requests and refuse this request.", next_req_time,
|
|
||||||
cur_iterator);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (record_public_keys.find(fl_id) != record_public_keys.end()) { // the client already exists, return false.
|
|
||||||
MS_LOG(INFO) << "The server has received the request, please do not request again.";
|
|
||||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
|
|
||||||
"The server has received the request, please do not request again.", next_req_time,
|
"The server has received the request, please do not request again.", next_req_time,
|
||||||
cur_iterator);
|
cur_iterator);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gets the members of the deserialized data : exchange_keys_req
|
|
||||||
auto fbs_cpk = exchange_keys_req->c_pk();
|
|
||||||
size_t cpk_len = fbs_cpk->size();
|
|
||||||
auto fbs_spk = exchange_keys_req->s_pk();
|
|
||||||
size_t spk_len = fbs_spk->size();
|
|
||||||
|
|
||||||
// transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
|
|
||||||
std::vector<std::vector<unsigned char>> cur_public_key;
|
|
||||||
std::vector<unsigned char> cpk(cpk_len);
|
|
||||||
std::vector<unsigned char> spk(spk_len);
|
|
||||||
bool ret_create_code_cpk = CreateArray<unsigned char>(&cpk, *fbs_cpk);
|
|
||||||
bool ret_create_code_spk = CreateArray<unsigned char>(&spk, *fbs_spk);
|
|
||||||
if (!(ret_create_code_cpk && ret_create_code_spk)) {
|
|
||||||
MS_LOG(ERROR) << "create cur_public_key failed";
|
|
||||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, "update key or client failed",
|
|
||||||
next_req_time, cur_iterator);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
cur_public_key.push_back(cpk);
|
|
||||||
cur_public_key.push_back(spk);
|
|
||||||
|
|
||||||
bool retcode_key =
|
bool retcode_key =
|
||||||
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, fl_id, cur_public_key);
|
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
|
||||||
bool retcode_client =
|
bool retcode_client =
|
||||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
|
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
|
||||||
if (retcode_key && retcode_client) {
|
if (retcode_key && retcode_client) {
|
||||||
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
|
MS_LOG(INFO) << "The client " << fl_id << " CipherMgr::ExchangeKeys Success";
|
||||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_SUCCEED,
|
BuildExchangeKeysRsp(fbb, schema::ResponseCode_SUCCEED, "Success, but the server is not ready yet.", next_req_time,
|
||||||
"Success, but the server is not ready yet.", next_req_time, cur_iterator);
|
cur_iterator);
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "update key or client failed";
|
MS_LOG(ERROR) << "update key or client failed";
|
||||||
BuildExchangeKeysRsp(exchange_keys_resp_builder, schema::ResponseCode_OutOfTime, "update key or client failed",
|
BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "update key or client failed", next_req_time,
|
||||||
next_req_time, cur_iterator);
|
cur_iterator);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherKeys::BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder,
|
void CipherKeys::BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||||
const schema::ResponseCode retcode, const std::string &reason,
|
const std::string &reason, const std::string &next_req_time,
|
||||||
const std::string &next_req_time, const int iteration) {
|
const int iteration) {
|
||||||
auto rsp_reason = exchange_keys_resp_builder->CreateString(reason);
|
auto rsp_reason = fbb->CreateString(reason);
|
||||||
auto rsp_next_req_time = exchange_keys_resp_builder->CreateString(next_req_time);
|
auto rsp_next_req_time = fbb->CreateString(next_req_time);
|
||||||
schema::ResponseExchangeKeysBuilder rsp_builder(*(exchange_keys_resp_builder.get()));
|
|
||||||
|
schema::ResponseExchangeKeysBuilder rsp_builder(*(fbb.get()));
|
||||||
rsp_builder.add_retcode(retcode);
|
rsp_builder.add_retcode(retcode);
|
||||||
rsp_builder.add_reason(rsp_reason);
|
rsp_builder.add_reason(rsp_reason);
|
||||||
rsp_builder.add_next_req_time(rsp_next_req_time);
|
rsp_builder.add_next_req_time(rsp_next_req_time);
|
||||||
rsp_builder.add_iteration(iteration);
|
rsp_builder.add_iteration(iteration);
|
||||||
auto rsp_exchange_keys = rsp_builder.Finish();
|
auto rsp_exchange_keys = rsp_builder.Finish();
|
||||||
exchange_keys_resp_builder->Finish(rsp_exchange_keys);
|
fbb->Finish(rsp_exchange_keys);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherKeys::BuildGetKeys(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
void CipherKeys::BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||||
const int iteration, const std::string &next_req_time, bool is_good) {
|
const int iteration, const std::string &next_req_time, bool is_good) {
|
||||||
bool flag = true;
|
if (!is_good) {
|
||||||
if (is_good) {
|
|
||||||
// convert client keys to standard keys list.
|
|
||||||
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
|
|
||||||
MS_LOG(INFO) << "Get Keys: ";
|
|
||||||
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
|
||||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
|
||||||
if (record_public_keys.size() < cipher_init_->client_num_need_) {
|
|
||||||
MS_LOG(INFO) << "NOT READY. keys num: " << record_public_keys.size()
|
|
||||||
<< "clients num: " << cipher_init_->client_num_need_;
|
|
||||||
flag = false;
|
|
||||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
|
||||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
|
||||||
rsp_buider.add_retcode(retcode);
|
|
||||||
rsp_buider.add_iteration(iteration);
|
|
||||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
|
||||||
auto rsp_get_keys = rsp_buider.Finish();
|
|
||||||
|
|
||||||
fbb->Finish(rsp_get_keys);
|
|
||||||
} else {
|
|
||||||
for (auto iter = record_public_keys.begin(); iter != record_public_keys.end(); ++iter) {
|
|
||||||
// read (fl_id, c_pk, s_pk) from the map: record_public_keys_
|
|
||||||
std::string fl_id = iter->first;
|
|
||||||
MS_LOG(INFO) << "fl_id : " << fl_id;
|
|
||||||
|
|
||||||
// To serialize the members to a new Table:ClientPublicKeys
|
|
||||||
auto fbs_fl_id = fbb->CreateString(fl_id);
|
|
||||||
auto fbs_c_pk = fbb->CreateVector(iter->second[0].data(), iter->second[0].size());
|
|
||||||
auto fbs_s_pk = fbb->CreateVector(iter->second[1].data(), iter->second[1].size());
|
|
||||||
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk);
|
|
||||||
public_keys_list.push_back(cur_public_key);
|
|
||||||
}
|
|
||||||
auto remote_publickeys = fbb->CreateVector(public_keys_list);
|
|
||||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
|
||||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
|
||||||
rsp_buider.add_retcode(retcode);
|
|
||||||
rsp_buider.add_iteration(iteration);
|
|
||||||
rsp_buider.add_remote_publickeys(remote_publickeys);
|
|
||||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
|
||||||
auto rsp_get_keys = rsp_buider.Finish();
|
|
||||||
fbb->Finish(rsp_get_keys);
|
|
||||||
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||||
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||||
rsp_buider.add_retcode(retcode);
|
rsp_buider.add_retcode(retcode);
|
||||||
rsp_buider.add_iteration(iteration);
|
rsp_buider.add_iteration(iteration);
|
||||||
rsp_buider.add_next_req_time(fbs_next_req_time);
|
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||||
auto rsp_get_keys = rsp_buider.Finish();
|
auto rsp_get_keys = rsp_buider.Finish();
|
||||||
|
|
||||||
fbb->Finish(rsp_get_keys);
|
fbb->Finish(rsp_get_keys);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return flag;
|
const fl::PBMetadata &clients_keys_pb_out =
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientsKeys);
|
||||||
|
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
||||||
|
std::vector<flatbuffers::Offset<schema::ClientPublicKeys>> public_keys_list;
|
||||||
|
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
|
||||||
|
std::string fl_id = iter->first;
|
||||||
|
fl::KeysPb keys_pb = iter->second;
|
||||||
|
auto fbs_fl_id = fbb->CreateString(fl_id);
|
||||||
|
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||||
|
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
||||||
|
auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size());
|
||||||
|
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
|
||||||
|
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
|
||||||
|
auto fbs_pw_iv = fbb->CreateVector(pw_iv.data(), pw_iv.size());
|
||||||
|
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
|
||||||
|
auto fbs_pw_salt = fbb->CreateVector(pw_salt.data(), pw_salt.size());
|
||||||
|
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
|
||||||
|
public_keys_list.push_back(cur_public_key);
|
||||||
|
}
|
||||||
|
auto remote_publickeys = fbb->CreateVector(public_keys_list);
|
||||||
|
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||||
|
schema::ReturnExchangeKeysBuilder rsp_buider(*(fbb.get()));
|
||||||
|
rsp_buider.add_retcode(retcode);
|
||||||
|
rsp_buider.add_iteration(iteration);
|
||||||
|
rsp_buider.add_remote_publickeys(remote_publickeys);
|
||||||
|
rsp_buider.add_next_req_time(fbs_next_req_time);
|
||||||
|
auto rsp_get_keys = rsp_buider.Finish();
|
||||||
|
fbb->Finish(rsp_get_keys);
|
||||||
|
MS_LOG(INFO) << "CipherMgr::GetKeys Success";
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherKeys::ClearKeys() {
|
void CipherKeys::ClearKeys() {
|
||||||
|
|
|
@ -44,21 +44,19 @@ class CipherKeys {
|
||||||
|
|
||||||
// handle the client's request of get keys.
|
// handle the client's request of get keys.
|
||||||
bool GetKeys(const int cur_iterator, const std::string &next_req_time,
|
bool GetKeys(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::GetExchangeKeys *get_exchange_keys_req,
|
const schema::GetExchangeKeys *get_exchange_keys_req, const std::shared_ptr<fl::server::FBBuilder> &fbb);
|
||||||
const std::shared_ptr<fl::server::FBBuilder> &get_exchange_keys_resp_builder);
|
|
||||||
|
|
||||||
// handle the client's request of exchange keys.
|
// handle the client's request of exchange keys.
|
||||||
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
bool ExchangeKeys(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::RequestExchangeKeys *exchange_keys_req,
|
const schema::RequestExchangeKeys *exchange_keys_req,
|
||||||
const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder);
|
const std::shared_ptr<fl::server::FBBuilder> &fbb);
|
||||||
|
|
||||||
// build response code of get keys.
|
// build response code of get keys.
|
||||||
bool BuildGetKeys(const std::shared_ptr<fl::server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
void BuildGetKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||||
const int iteration, const std::string &next_req_time, bool is_good);
|
const int iteration, const std::string &next_req_time, bool is_good);
|
||||||
// build response code of exchange keys.
|
// build response code of exchange keys.
|
||||||
void BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &exchange_keys_resp_builder,
|
void BuildExchangeKeysRsp(std::shared_ptr<fl::server::FBBuilder> fbb, const schema::ResponseCode retcode,
|
||||||
const schema::ResponseCode retcode, const std::string &reason,
|
const std::string &reason, const std::string &next_req_time, const int iteration);
|
||||||
const std::string &next_req_time, const int iteration);
|
|
||||||
// clear the shared memory.
|
// clear the shared memory.
|
||||||
void ClearKeys();
|
void ClearKeys();
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ void CipherMetaStorage::GetClientListFromServer(const char *list_name, std::vect
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherMetaStorage::GetClientKeysFromServer(
|
void CipherMetaStorage::GetClientKeysFromServer(
|
||||||
const char *list_name, std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list) {
|
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) {
|
||||||
const fl::PBMetadata &clients_keys_pb_out =
|
const fl::PBMetadata &clients_keys_pb_out =
|
||||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||||
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
||||||
|
@ -60,12 +60,33 @@ void CipherMetaStorage::GetClientKeysFromServer(
|
||||||
// const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
|
// const PairClientKeys & pair_client_keys_pb = clients_keys_pb.client_keys(i);
|
||||||
std::string fl_id = iter->first;
|
std::string fl_id = iter->first;
|
||||||
fl::KeysPb keys_pb = iter->second;
|
fl::KeysPb keys_pb = iter->second;
|
||||||
std::vector<unsigned char> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||||
std::vector<unsigned char> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
||||||
std::vector<std::vector<unsigned char>> cur_keys;
|
std::vector<std::vector<uint8_t>> cur_keys;
|
||||||
cur_keys.push_back(cpk);
|
cur_keys.push_back(cpk);
|
||||||
cur_keys.push_back(spk);
|
cur_keys.push_back(spk);
|
||||||
clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<unsigned char>>>(fl_id, cur_keys));
|
clients_keys_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CipherMetaStorage::GetClientIVsFromServer(
|
||||||
|
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) {
|
||||||
|
const fl::PBMetadata &clients_keys_pb_out =
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||||
|
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
||||||
|
|
||||||
|
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
|
||||||
|
std::string fl_id = iter->first;
|
||||||
|
fl::KeysPb keys_pb = iter->second;
|
||||||
|
std::vector<uint8_t> ind_iv(keys_pb.ind_iv().begin(), keys_pb.ind_iv().end());
|
||||||
|
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
|
||||||
|
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
|
||||||
|
|
||||||
|
std::vector<std::vector<uint8_t>> cur_ivs;
|
||||||
|
cur_ivs.push_back(ind_iv);
|
||||||
|
cur_ivs.push_back(pw_iv);
|
||||||
|
cur_ivs.push_back(pw_salt);
|
||||||
|
clients_ivs_list->insert(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_ivs));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,9 +94,18 @@ bool CipherMetaStorage::GetClientNoisesFromServer(const char *list_name, std::ve
|
||||||
const fl::PBMetadata &clients_noises_pb_out =
|
const fl::PBMetadata &clients_noises_pb_out =
|
||||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||||
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
||||||
|
int count = 0;
|
||||||
|
int count_thld = 100;
|
||||||
while (clients_noises_pb.has_one_client_noises() == false) {
|
while (clients_noises_pb.has_one_client_noises() == false) {
|
||||||
MS_LOG(INFO) << "GetClientNoisesFromServer NULL.";
|
int register_time = 500;
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
std::this_thread::sleep_for(std::chrono::milliseconds(register_time));
|
||||||
|
count++;
|
||||||
|
if (count >= count_thld) break;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "GetClientNoisesFromServer Count: " << count;
|
||||||
|
if (clients_noises_pb.has_one_client_noises() == false) {
|
||||||
|
MS_LOG(ERROR) << "GetClientNoisesFromServer NULL.";
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(),
|
cur_public_noise->assign(clients_noises_pb.one_client_noises().noise().begin(),
|
||||||
clients_noises_pb.one_client_noises().noise().end());
|
clients_noises_pb.one_client_noises().noise().end());
|
||||||
|
@ -98,14 +128,14 @@ bool CipherMetaStorage::GetPrimeFromServer(const char *prime_name, unsigned char
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
|
bool CipherMetaStorage::UpdateClientToServer(const char *list_name, const std::string &fl_id) {
|
||||||
bool retcode = true;
|
|
||||||
fl::FLId fl_id_pb;
|
fl::FLId fl_id_pb;
|
||||||
fl_id_pb.set_fl_id(fl_id);
|
fl_id_pb.set_fl_id(fl_id);
|
||||||
fl::PBMetadata client_pb;
|
fl::PBMetadata client_pb;
|
||||||
client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
|
client_pb.mutable_fl_id()->MergeFrom(fl_id_pb);
|
||||||
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
|
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_pb);
|
||||||
return retcode;
|
return retcode;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
|
void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &prime) {
|
||||||
MS_LOG(INFO) << "register prime: " << prime;
|
MS_LOG(INFO) << "register prime: " << prime;
|
||||||
fl::Prime prime_id_pb;
|
fl::Prime prime_id_pb;
|
||||||
|
@ -117,9 +147,9 @@ void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
||||||
const std::vector<std::vector<unsigned char>> &cur_public_key) {
|
const std::vector<std::vector<uint8_t>> &cur_public_key) {
|
||||||
bool retcode = true;
|
size_t correct_size = 2;
|
||||||
if (cur_public_key.size() < 2) {
|
if (cur_public_key.size() < correct_size) {
|
||||||
MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
|
MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -132,7 +162,73 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std
|
||||||
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
|
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
|
||||||
fl::PBMetadata client_and_keys_pb;
|
fl::PBMetadata client_and_keys_pb;
|
||||||
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
|
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
|
||||||
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
||||||
|
return retcode;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
|
||||||
|
const schema::RequestExchangeKeys *exchange_keys_req) {
|
||||||
|
std::string fl_id = exchange_keys_req->fl_id()->str();
|
||||||
|
auto fbs_cpk = exchange_keys_req->c_pk();
|
||||||
|
auto fbs_spk = exchange_keys_req->s_pk();
|
||||||
|
if (fbs_cpk == nullptr || fbs_spk == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "public key from exchange_keys_req is null";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t spk_len = fbs_spk->size();
|
||||||
|
size_t cpk_len = fbs_cpk->size();
|
||||||
|
|
||||||
|
// transform fbs (fbs_cpk & fbs_spk) to a vector: public_key
|
||||||
|
std::vector<std::vector<uint8_t>> cur_public_key;
|
||||||
|
std::vector<uint8_t> cpk(cpk_len);
|
||||||
|
std::vector<uint8_t> spk(spk_len);
|
||||||
|
bool ret_create_code_cpk = CreateArray<uint8_t>(&cpk, *fbs_cpk);
|
||||||
|
bool ret_create_code_spk = CreateArray<uint8_t>(&spk, *fbs_spk);
|
||||||
|
if (!(ret_create_code_cpk && ret_create_code_spk)) {
|
||||||
|
MS_LOG(ERROR) << "create array for public keys failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
cur_public_key.push_back(cpk);
|
||||||
|
cur_public_key.push_back(spk);
|
||||||
|
|
||||||
|
auto fbs_ind_iv = exchange_keys_req->ind_iv();
|
||||||
|
std::vector<char> ind_iv;
|
||||||
|
if (fbs_ind_iv == nullptr) {
|
||||||
|
MS_LOG(WARNING) << "ind_iv in exchange_keys_req is nullptr";
|
||||||
|
} else {
|
||||||
|
ind_iv.assign(fbs_ind_iv->begin(), fbs_ind_iv->end());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto fbs_pw_iv = exchange_keys_req->pw_iv();
|
||||||
|
std::vector<char> pw_iv;
|
||||||
|
if (fbs_pw_iv == nullptr) {
|
||||||
|
MS_LOG(WARNING) << "pw_iv in exchange_keys_req is nullptr";
|
||||||
|
} else {
|
||||||
|
pw_iv.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto fbs_pw_salt = exchange_keys_req->pw_salt();
|
||||||
|
std::vector<char> pw_salt;
|
||||||
|
if (fbs_pw_salt == nullptr) {
|
||||||
|
MS_LOG(WARNING) << "pw_salt in exchange_keys_req is nullptr";
|
||||||
|
} else {
|
||||||
|
pw_salt.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
|
||||||
|
}
|
||||||
|
|
||||||
|
// update new item to memory server.
|
||||||
|
fl::KeysPb keys;
|
||||||
|
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
|
||||||
|
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
|
||||||
|
keys.set_ind_iv(ind_iv.data(), ind_iv.size());
|
||||||
|
keys.set_pw_iv(pw_iv.data(), pw_iv.size());
|
||||||
|
keys.set_pw_salt(pw_salt.data(), pw_salt.size());
|
||||||
|
fl::PairClientKeys pair_client_keys_pb;
|
||||||
|
pair_client_keys_pb.set_fl_id(fl_id);
|
||||||
|
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
|
||||||
|
fl::PBMetadata client_and_keys_pb;
|
||||||
|
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
|
||||||
|
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
||||||
return retcode;
|
return retcode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,13 +238,13 @@ bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const s
|
||||||
*noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
|
*noises_pb.mutable_noise() = {cur_public_noise.begin(), cur_public_noise.end()};
|
||||||
fl::PBMetadata client_noises_pb;
|
fl::PBMetadata client_noises_pb;
|
||||||
client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
|
client_noises_pb.mutable_one_client_noises()->MergeFrom(noises_pb);
|
||||||
return fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
|
bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_noises_pb);
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherMetaStorage::UpdateClientShareToServer(
|
bool CipherMetaStorage::UpdateClientShareToServer(
|
||||||
const char *list_name, const std::string &fl_id,
|
const char *list_name, const std::string &fl_id,
|
||||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
|
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *shares) {
|
||||||
bool retcode = true;
|
|
||||||
int size_shares = shares->size();
|
int size_shares = shares->size();
|
||||||
fl::SharesPb shares_pb;
|
fl::SharesPb shares_pb;
|
||||||
for (int index = 0; index < size_shares; ++index) {
|
for (int index = 0; index < size_shares; ++index) {
|
||||||
|
@ -166,14 +262,17 @@ bool CipherMetaStorage::UpdateClientShareToServer(
|
||||||
pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
|
pair_client_shares_pb.mutable_client_shares()->MergeFrom(shares_pb);
|
||||||
fl::PBMetadata client_and_shares_pb;
|
fl::PBMetadata client_and_shares_pb;
|
||||||
client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
|
client_and_shares_pb.mutable_pair_client_shares()->MergeFrom(pair_client_shares_pb);
|
||||||
retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
|
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_shares_pb);
|
||||||
return retcode;
|
return retcode;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherMetaStorage::RegisterClass() {
|
void CipherMetaStorage::RegisterClass() {
|
||||||
fl::PBMetadata exchange_kyes_client_list;
|
fl::PBMetadata exchange_keys_client_list;
|
||||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
|
||||||
exchange_kyes_client_list);
|
exchange_keys_client_list);
|
||||||
|
fl::PBMetadata get_keys_client_list;
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetKeysClientList,
|
||||||
|
get_keys_client_list);
|
||||||
fl::PBMetadata clients_keys;
|
fl::PBMetadata clients_keys;
|
||||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
|
||||||
fl::PBMetadata reconstruct_client_list;
|
fl::PBMetadata reconstruct_client_list;
|
||||||
|
@ -185,9 +284,15 @@ void CipherMetaStorage::RegisterClass() {
|
||||||
fl::PBMetadata share_secretes_client_list;
|
fl::PBMetadata share_secretes_client_list;
|
||||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxShareSecretsClientList,
|
||||||
share_secretes_client_list);
|
share_secretes_client_list);
|
||||||
|
fl::PBMetadata get_secretes_client_list;
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetSecretsClientList,
|
||||||
|
get_secretes_client_list);
|
||||||
fl::PBMetadata clients_encrypt_shares;
|
fl::PBMetadata clients_encrypt_shares;
|
||||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsEncryptedShares,
|
||||||
clients_encrypt_shares);
|
clients_encrypt_shares);
|
||||||
|
fl::PBMetadata get_update_clients_list;
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetUpdateModelClientList,
|
||||||
|
get_update_clients_list);
|
||||||
}
|
}
|
||||||
} // namespace armour
|
} // namespace armour
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -31,23 +31,39 @@
|
||||||
#include "fl/server/distributed_metadata_store.h"
|
#include "fl/server/distributed_metadata_store.h"
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
|
||||||
|
#define IND_IV_INDEX 0
|
||||||
|
#define PW_IV_INDEX 1
|
||||||
|
#define PW_SALT_INDEX 2
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
|
|
||||||
|
template <typename T1>
|
||||||
|
bool CreateArray(std::vector<T1> *newData, const flatbuffers::Vector<T1> &fbs_arr) {
|
||||||
|
if (newData == nullptr) return false;
|
||||||
|
size_t size = newData->size();
|
||||||
|
size_t size_fbs_arr = fbs_arr.size();
|
||||||
|
if (size != size_fbs_arr) return false;
|
||||||
|
for (size_t i = 0; i < size; ++i) {
|
||||||
|
newData->at(i) = fbs_arr.Get(i);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
constexpr int SHARE_MAX_SIZE = 256;
|
constexpr int SHARE_MAX_SIZE = 256;
|
||||||
constexpr int SECRET_MAX_LEN_DOUBLE = 66;
|
constexpr int SECRET_MAX_LEN_DOUBLE = 66;
|
||||||
|
|
||||||
struct clientshare_str {
|
struct clientshare_str {
|
||||||
std::string fl_id;
|
std::string fl_id;
|
||||||
std::vector<unsigned char> share;
|
std::vector<uint8_t> share;
|
||||||
int index;
|
int index;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct CipherPublicPara {
|
struct CipherPublicPara {
|
||||||
int t;
|
int t;
|
||||||
int g;
|
int g;
|
||||||
unsigned char prime[PRIME_MAX_LEN];
|
uint8_t prime[PRIME_MAX_LEN];
|
||||||
unsigned char p[SECRET_MAX_LEN];
|
uint8_t p[SECRET_MAX_LEN];
|
||||||
float dp_eps;
|
float dp_eps;
|
||||||
float dp_delta;
|
float dp_delta;
|
||||||
float dp_norm_clip;
|
float dp_norm_clip;
|
||||||
|
@ -62,7 +78,7 @@ class CipherMetaStorage {
|
||||||
// Register Prime.
|
// Register Prime.
|
||||||
void RegisterPrime(const char *list_name, const std::string &prime);
|
void RegisterPrime(const char *list_name, const std::string &prime);
|
||||||
// Get tprime from shared server.
|
// Get tprime from shared server.
|
||||||
bool GetPrimeFromServer(const char *prime_name, unsigned char *prime);
|
bool GetPrimeFromServer(const char *prime_name, uint8_t *prime);
|
||||||
// Get client shares from shared server.
|
// Get client shares from shared server.
|
||||||
void GetClientSharesFromServer(const char *list_name,
|
void GetClientSharesFromServer(const char *list_name,
|
||||||
std::map<std::string, std::vector<clientshare_str>> *clients_shares_list);
|
std::map<std::string, std::vector<clientshare_str>> *clients_shares_list);
|
||||||
|
@ -70,14 +86,18 @@ class CipherMetaStorage {
|
||||||
void GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list);
|
void GetClientListFromServer(const char *list_name, std::vector<std::string> *clients_list);
|
||||||
// Get client keys from shared server.
|
// Get client keys from shared server.
|
||||||
void GetClientKeysFromServer(const char *list_name,
|
void GetClientKeysFromServer(const char *list_name,
|
||||||
std::map<std::string, std::vector<std::vector<unsigned char>>> *clients_keys_list);
|
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list);
|
||||||
|
void GetClientIVsFromServer(const char *list_name,
|
||||||
|
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list);
|
||||||
// Get client noises from shared server.
|
// Get client noises from shared server.
|
||||||
bool GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise);
|
bool GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise);
|
||||||
// Update client fl_id to shared server.
|
// Update client fl_id to shared server.
|
||||||
bool UpdateClientToServer(const char *list_name, const std::string &fl_id);
|
bool UpdateClientToServer(const char *list_name, const std::string &fl_id);
|
||||||
// Update client key to shared server.
|
// Update client key to shared server.
|
||||||
bool UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
bool UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
||||||
const std::vector<std::vector<unsigned char>> &cur_public_key);
|
const std::vector<std::vector<uint8_t>> &cur_public_key);
|
||||||
|
// Update client key with signature to shared server.
|
||||||
|
bool UpdateClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req);
|
||||||
// Update client noise to shared server.
|
// Update client noise to shared server.
|
||||||
bool UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise);
|
bool UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise);
|
||||||
// Update client share to shared server.
|
// Update client share to shared server.
|
||||||
|
|
|
@ -16,33 +16,35 @@
|
||||||
|
|
||||||
#include "fl/armour/cipher/cipher_reconstruct.h"
|
#include "fl/armour/cipher/cipher_reconstruct.h"
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
#include "fl/armour/secure_protocol/random.h"
|
#include "fl/armour/secure_protocol/masking.h"
|
||||||
#include "fl/armour/secure_protocol/key_agreement.h"
|
#include "fl/armour/secure_protocol/key_agreement.h"
|
||||||
#include "fl/armour/cipher/cipher_meta_storage.h"
|
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
bool CipherReconStruct::CombineMask(
|
bool CipherReconStruct::CombineMask(std::vector<Share *> *shares_tmp,
|
||||||
std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_keys,
|
std::map<std::string, std::vector<float>> *client_noise,
|
||||||
const std::vector<std::string> &clients_share_list,
|
const std::vector<std::string> &clients_share_list,
|
||||||
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
|
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
|
||||||
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
|
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
|
||||||
const std::vector<string> &client_list) {
|
const std::vector<string> &client_list,
|
||||||
|
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs) {
|
||||||
bool retcode = true;
|
bool retcode = true;
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||||
retcode = false;
|
retcode = false;
|
||||||
#else
|
#else
|
||||||
|
if (shares_tmp == nullptr || client_noise == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "shares_tmp or client_noise is nullptr.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
|
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
|
||||||
// define flag_share: judge we need b or s
|
// define flag_share: judge we need b or s
|
||||||
bool flag_share = true;
|
bool flag_share = true;
|
||||||
const std::string fl_id = iter->first;
|
const std::string fl_id = iter->first;
|
||||||
std::vector<std::string>::const_iterator ptr = client_list.begin();
|
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) {
|
||||||
for (; ptr < client_list.end(); ++ptr) {
|
// the client is online
|
||||||
if (*ptr == fl_id) {
|
flag_share = false;
|
||||||
flag_share = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "fl_id_src : " << fl_id;
|
MS_LOG(INFO) << "fl_id_src : " << fl_id;
|
||||||
BIGNUM *prime = BN_new();
|
BIGNUM *prime = BN_new();
|
||||||
|
@ -61,7 +63,6 @@ bool CipherReconStruct::CombineMask(
|
||||||
MS_LOG(ERROR) << "shares_tmp copy failed";
|
MS_LOG(ERROR) << "shares_tmp copy failed";
|
||||||
retcode = false;
|
retcode = false;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "fl_id_des : " << (iter->second)[i].fl_id;
|
|
||||||
std::string print_share_data(reinterpret_cast<const char *>(shares_tmp->at(i)->data), shares_tmp->at(i)->len);
|
std::string print_share_data(reinterpret_cast<const char *>(shares_tmp->at(i)->data), shares_tmp->at(i)->len);
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "end assign secrets shares to public shares ";
|
MS_LOG(INFO) << "end assign secrets shares to public shares ";
|
||||||
|
@ -74,23 +75,42 @@ bool CipherReconStruct::CombineMask(
|
||||||
MS_LOG(INFO) << "combine secrets shares Success.";
|
MS_LOG(INFO) << "combine secrets shares Success.";
|
||||||
|
|
||||||
if (flag_share) {
|
if (flag_share) {
|
||||||
MS_LOG(INFO) << "start get complete s_uv.";
|
// reconstruct pairwise noise
|
||||||
|
MS_LOG(INFO) << "start reconstruct pairwise noise.";
|
||||||
std::vector<float> noise(cipher_init_->featuremap_, 0.0);
|
std::vector<float> noise(cipher_init_->featuremap_, 0.0);
|
||||||
if (GetSuvNoise(clients_share_list, record_public_keys, fl_id, &noise, secret, length) == false)
|
if (GetSuvNoise(clients_share_list, record_public_keys, client_ivs, fl_id, &noise, secret, length) == false)
|
||||||
retcode = false;
|
retcode = false;
|
||||||
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
||||||
MS_LOG(INFO) << " fl_id : " << fl_id;
|
MS_LOG(INFO) << " fl_id : " << fl_id;
|
||||||
MS_LOG(INFO) << "end get complete s_uv.";
|
MS_LOG(INFO) << "end get complete s_uv.";
|
||||||
} else {
|
} else {
|
||||||
|
// reconstruct individual noise
|
||||||
|
MS_LOG(INFO) << "start reconstruct individual noise.";
|
||||||
std::vector<float> noise;
|
std::vector<float> noise;
|
||||||
if (Random::RandomAESCTR(&noise, cipher_init_->featuremap_, (const unsigned char *)secret, SECRET_MAX_LEN) < 0)
|
auto it = client_ivs.find(fl_id);
|
||||||
|
if (it == client_ivs.end()) {
|
||||||
|
MS_LOG(ERROR) << "cannot get ivs for client: " << fl_id;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (it->second.size() != IV_NUM) {
|
||||||
|
MS_LOG(ERROR) << "get " << it->second.size() << " ivs, the iv num required is: " << IV_NUM;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<uint8_t> ind_iv = it->second[0];
|
||||||
|
if (Masking::GetMasking(&noise, cipher_init_->featuremap_, (const uint8_t *)secret, SECRET_MAX_LEN,
|
||||||
|
ind_iv.data(), ind_iv.size()) < 0)
|
||||||
retcode = false;
|
retcode = false;
|
||||||
for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) {
|
for (size_t index_noise = 0; index_noise < cipher_init_->featuremap_; index_noise++) {
|
||||||
noise[index_noise] *= -1;
|
noise[index_noise] *= -1;
|
||||||
}
|
}
|
||||||
client_keys->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
client_noise->insert(std::pair<std::string, std::vector<float>>(fl_id, noise));
|
||||||
MS_LOG(INFO) << " fl_id : " << fl_id;
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "reconstruct secret failed: the number of secret shares for fl_id: " << fl_id
|
||||||
|
<< " is not enough";
|
||||||
|
MS_LOG(ERROR) << "get " << iter->second.size()
|
||||||
|
<< "shares, however the secrets_minnums_ required is: " << cipher_init_->secrets_minnums_;
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -98,173 +118,188 @@ bool CipherReconStruct::CombineMask(
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &client_list) {
|
bool CipherReconStruct::ReconstructSecretsGenNoise(const std::vector<string> &client_list) {
|
||||||
// get reconstruct_secret_list_ori from memory server
|
// get reconstruct_secrets from memory server
|
||||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
|
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecretsGenNoise START";
|
||||||
bool retcode = true;
|
bool retcode = true;
|
||||||
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list_ori;
|
std::map<std::string, std::vector<clientshare_str>> reconstruct_secrets;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
|
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
|
||||||
&reconstruct_secret_list_ori);
|
&reconstruct_secrets);
|
||||||
std::map<std::string, std::vector<std::vector<unsigned char>>> record_public_keys;
|
std::map<std::string, std::vector<std::vector<uint8_t>>> record_public_keys;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &record_public_keys);
|
||||||
std::vector<std::string> clients_reconstruct_list;
|
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
|
std::map<std::string, std::vector<std::vector<uint8_t>>> client_ivs;
|
||||||
&clients_reconstruct_list);
|
cipher_init_->cipher_meta_storage_.GetClientIVsFromServer(fl::server::kCtxClientsKeys, &client_ivs);
|
||||||
|
|
||||||
std::vector<std::string> clients_share_list;
|
std::vector<std::string> clients_share_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
||||||
&clients_share_list);
|
&clients_share_list);
|
||||||
if (reconstruct_secret_list_ori.size() != clients_reconstruct_list.size() ||
|
if (record_public_keys.size() < cipher_init_->exchange_key_threshold ||
|
||||||
record_public_keys.size() < cipher_init_->client_num_need_ ||
|
clients_share_list.size() < cipher_init_->share_secrets_threshold ||
|
||||||
clients_share_list.size() < cipher_init_->share_clients_num_need_) {
|
record_public_keys.size() != client_ivs.size()) {
|
||||||
|
MS_LOG(ERROR) << "send share client size: " << clients_share_list.size()
|
||||||
|
<< ", send public-key client size: " << record_public_keys.size()
|
||||||
|
<< ", send ivs client size: " << client_ivs.size();
|
||||||
MS_LOG(ERROR) << "get data from server memory failed";
|
MS_LOG(ERROR) << "get data from server memory failed";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list;
|
std::map<std::string, std::vector<clientshare_str>> reconstruct_secret_list;
|
||||||
ConvertSharesToShares(reconstruct_secret_list_ori, &reconstruct_secret_list);
|
if (!ConvertSharesToShares(reconstruct_secrets, &reconstruct_secret_list)) {
|
||||||
|
MS_LOG(ERROR) << "ConvertSharesToShares failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(ERROR) << "recombined shares";
|
||||||
|
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
|
||||||
|
MS_LOG(ERROR) << "fl_id: " << iter->first;
|
||||||
|
MS_LOG(ERROR) << "share size: " << iter->second.size();
|
||||||
|
}
|
||||||
std::vector<Share *> shares_tmp;
|
std::vector<Share *> shares_tmp;
|
||||||
if (MallocShares(&shares_tmp, cipher_init_->secrets_minnums_) == false) {
|
if (!MallocShares(&shares_tmp, cipher_init_->secrets_minnums_)) {
|
||||||
MS_LOG(ERROR) << "Reconstruct malloc shares_tmp invalid.";
|
MS_LOG(ERROR) << "Reconstruct malloc shares_tmp invalid.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Reconstruct client list: ";
|
|
||||||
std::vector<std::string>::const_iterator ptr_tmp = client_list.begin();
|
|
||||||
for (; ptr_tmp < client_list.end(); ++ptr_tmp) {
|
|
||||||
MS_LOG(INFO) << *ptr_tmp;
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "Reconstruct secrets shares: ";
|
MS_LOG(INFO) << "Reconstruct secrets shares: ";
|
||||||
std::map<std::string, std::vector<float>> client_keys;
|
std::map<std::string, std::vector<float>> client_noise;
|
||||||
|
retcode = CombineMask(&shares_tmp, &client_noise, clients_share_list, record_public_keys, reconstruct_secret_list,
|
||||||
retcode = CombineMask(&shares_tmp, &client_keys, clients_share_list, record_public_keys, reconstruct_secret_list,
|
client_list, client_ivs);
|
||||||
client_list);
|
|
||||||
|
|
||||||
DeleteShares(&shares_tmp);
|
DeleteShares(&shares_tmp);
|
||||||
if (retcode) {
|
if (retcode) {
|
||||||
std::vector<float> noise;
|
std::vector<float> noise;
|
||||||
if (GetNoiseMasksSum(&noise, client_keys) == false) {
|
if (!GetNoiseMasksSum(&noise, client_noise)) {
|
||||||
MS_LOG(ERROR) << " GetNoiseMasksSum failed";
|
MS_LOG(ERROR) << " GetNoiseMasksSum failed";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
client_keys.clear();
|
client_noise.clear();
|
||||||
MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";
|
MS_LOG(INFO) << " ReconstructSecretsGenNoise updata noise to server";
|
||||||
|
|
||||||
if (cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise) == false)
|
if (!cipher_init_->cipher_meta_storage_.UpdateClientNoiseToServer(fl::server::kCtxClientNoises, noise)) {
|
||||||
|
MS_LOG(ERROR) << " ReconstructSecretsGenNoise failed. because UpdateClientNoiseToServer failed";
|
||||||
return false;
|
return false;
|
||||||
|
}
|
||||||
MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
|
MS_LOG(INFO) << " ReconstructSecretsGenNoise Success";
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(INFO) << " ReconstructSecretsGenNoise failed. because gen noise inside failed";
|
MS_LOG(ERROR) << " ReconstructSecretsGenNoise failed. because gen noise inside failed";
|
||||||
}
|
}
|
||||||
|
|
||||||
return retcode;
|
return retcode;
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconstruct secrets
|
// reconstruct secrets
|
||||||
bool CipherReconStruct::ReconstructSecrets(
|
bool CipherReconStruct::ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
|
||||||
const int cur_iterator, const std::string &next_req_time, const schema::SendReconstructSecret *reconstruct_secret_req,
|
const schema::SendReconstructSecret *reconstruct_secret_req,
|
||||||
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
|
const std::shared_ptr<fl::server::FBBuilder> &fbb,
|
||||||
const std::vector<std::string> &client_list) {
|
const std::vector<std::string> &client_list) {
|
||||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
|
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets START";
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
if (reconstruct_secret_req == nullptr || reconstruct_secret_resp_builder == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr. ";
|
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_RequestError,
|
|
||||||
"Request is nullptr or Response builder is nullptr.", cur_iterator, next_req_time);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (client_list.size() < cipher_init_->reconstruct_clients_num_need_) {
|
|
||||||
MS_LOG(ERROR) << "illegal parameters. update model client_list size: " << client_list.size();
|
|
||||||
BuildReconstructSecretsRsp(
|
|
||||||
reconstruct_secret_resp_builder, schema::ResponseCode_RequestError,
|
|
||||||
"illegal parameters: update model client_list size must larger than reconstruct_clients_num_need", cur_iterator,
|
|
||||||
next_req_time);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
std::vector<std::string> clients_reconstruct_list;
|
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxReconstructClientList,
|
|
||||||
&clients_reconstruct_list);
|
|
||||||
std::map<std::string, std::vector<clientshare_str>> clients_shares_all;
|
|
||||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
|
|
||||||
&clients_shares_all);
|
|
||||||
|
|
||||||
size_t count_client_num = clients_shares_all.size();
|
if (reconstruct_secret_req == nullptr) {
|
||||||
if (count_client_num != clients_reconstruct_list.size()) {
|
std::string reason = "Request is nullptr";
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
MS_LOG(ERROR) << reason;
|
||||||
"shares client size and client size are not equal.", cur_iterator, next_req_time);
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, cur_iterator, next_req_time);
|
||||||
MS_LOG(ERROR) << "shares client size and client size are not equal.";
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
int iterator = reconstruct_secret_req->iteration();
|
int iterator = reconstruct_secret_req->iteration();
|
||||||
std::string fl_id = reconstruct_secret_req->fl_id()->str();
|
std::string fl_id = reconstruct_secret_req->fl_id()->str();
|
||||||
if (iterator != cur_iterator) {
|
if (iterator != cur_iterator) {
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||||
"The iteration round of the client does not match the current iteration.", cur_iterator,
|
"The iteration round of the client does not match the current iteration.", cur_iterator,
|
||||||
next_req_time);
|
next_req_time);
|
||||||
MS_LOG(ERROR) << "Client " << fl_id << " The iteration round of the client does not match the current iteration.";
|
MS_LOG(ERROR) << "Client " << fl_id << " The iteration round of the client does not match the current iteration.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (find(client_list.begin(), client_list.end(), fl_id) == client_list.end()) { // client not in client list.
|
if (client_list.size() < cipher_init_->reconstruct_secrets_threshold) {
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
MS_LOG(ERROR) << "illegal parameters. update model client_list size: " << client_list.size();
|
||||||
"The client is not in update model client list.", cur_iterator, next_req_time);
|
BuildReconstructSecretsRsp(
|
||||||
MS_LOG(ERROR) << "The client " << fl_id << " is not in update model client list.";
|
fbb, schema::ResponseCode_RequestError,
|
||||||
|
"illegal parameters: update model client_list size must larger than reconstruct_clients_num_need", cur_iterator,
|
||||||
|
next_req_time);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (find(clients_reconstruct_list.begin(), clients_reconstruct_list.end(), fl_id) != clients_reconstruct_list.end()) {
|
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
|
std::vector<std::string> get_clients_list;
|
||||||
"Client has sended messages.", cur_iterator, next_req_time);
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetUpdateModelClientList,
|
||||||
|
&get_clients_list);
|
||||||
|
// client not in get client list.
|
||||||
|
if (find(get_clients_list.begin(), get_clients_list.end(), fl_id) == get_clients_list.end()) {
|
||||||
|
std::string reason;
|
||||||
|
MS_LOG(INFO) << "The client " << fl_id << " is not in get update model client list.";
|
||||||
|
// client in update model client list.
|
||||||
|
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) {
|
||||||
|
reason = "The client " + fl_id + " is not in get clients list, but in update model client list.";
|
||||||
|
MS_LOG(INFO) << reason;
|
||||||
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, reason, cur_iterator, next_req_time);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
reason = "The client " + fl_id + " is not in get clients list, and not in update model client list.";
|
||||||
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, "The client is not in update model client list.",
|
||||||
|
cur_iterator, next_req_time);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::string, std::vector<clientshare_str>> reconstruct_shares;
|
||||||
|
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsReconstructShares,
|
||||||
|
&reconstruct_shares);
|
||||||
|
size_t count_client_num = reconstruct_shares.size();
|
||||||
|
if (reconstruct_shares.find(fl_id) != reconstruct_shares.end()) {
|
||||||
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Client has sended messages.", cur_iterator,
|
||||||
|
next_req_time);
|
||||||
MS_LOG(INFO) << "Error, client " << fl_id << " has sended messages.";
|
MS_LOG(INFO) << "Error, client " << fl_id << " has sended messages.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
|
auto reconstruct_secret_shares = reconstruct_secret_req->reconstruct_secret_shares();
|
||||||
bool retcode_client =
|
bool retcode_client =
|
||||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id);
|
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxReconstructClientList, fl_id);
|
||||||
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
||||||
fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
|
fl::server::kCtxClientsReconstructShares, fl_id, reconstruct_secret_shares);
|
||||||
if (!(retcode_client && retcode_share)) {
|
if (!(retcode_client && retcode_share)) {
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "reconstruct update shares or client failed.",
|
||||||
"reconstruct update shares or client failed.", cur_iterator, next_req_time);
|
cur_iterator, next_req_time);
|
||||||
MS_LOG(ERROR) << "reconstruct update shares or client failed.";
|
MS_LOG(ERROR) << "reconstruct update shares or client failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
count_client_num = count_client_num + 1;
|
count_client_num = count_client_num + 1;
|
||||||
if (count_client_num < cipher_init_->reconstruct_clients_num_need_) {
|
if (count_client_num < cipher_init_->reconstruct_secrets_threshold) {
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED,
|
||||||
"Success,but the server is not ready to reconstruct secret yet.", cur_iterator,
|
"Success, but the server is not ready to reconstruct secret yet.", cur_iterator,
|
||||||
next_req_time);
|
next_req_time);
|
||||||
MS_LOG(INFO) << "ReconstructSecrets" << fl_id << " Success, but count " << count_client_num << "is not enough.";
|
MS_LOG(INFO) << "ReconstructSecrets " << fl_id << " Success, but count " << count_client_num << "is not enough.";
|
||||||
return true;
|
return true;
|
||||||
} else {
|
|
||||||
bool retcode_result = true;
|
|
||||||
const fl::PBMetadata &clients_noises_pb_out =
|
|
||||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
|
|
||||||
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
|
||||||
if (clients_noises_pb.has_one_client_noises() == false) {
|
|
||||||
MS_LOG(INFO) << "Success,the secret will be reconstructed.";
|
|
||||||
retcode_result = ReconstructSecretsGenNoise(client_list);
|
|
||||||
if (retcode_result) {
|
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
|
|
||||||
"Success,the secret is reconstructing.", cur_iterator, next_req_time);
|
|
||||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success, reconstruct ok.";
|
|
||||||
} else {
|
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_OutOfTime,
|
|
||||||
"the secret restructs failed.", cur_iterator, next_req_time);
|
|
||||||
MS_LOG(ERROR) << "the secret restructs failed.";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
BuildReconstructSecretsRsp(reconstruct_secret_resp_builder, schema::ResponseCode_SUCCEED,
|
|
||||||
"Clients' number is full.", cur_iterator, next_req_time);
|
|
||||||
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success : no need reconstruct.";
|
|
||||||
}
|
|
||||||
clock_t end_time = clock();
|
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
|
||||||
MS_LOG(INFO) << "Reconstruct get + gennoise data time is : " << duration;
|
|
||||||
return retcode_result;
|
|
||||||
}
|
}
|
||||||
|
const fl::PBMetadata &clients_noises_pb_out =
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientNoises);
|
||||||
|
const fl::ClientNoises &clients_noises_pb = clients_noises_pb_out.client_noises();
|
||||||
|
if (clients_noises_pb.has_one_client_noises() == false) {
|
||||||
|
MS_LOG(INFO) << "Success, the secret will be reconstructed.";
|
||||||
|
if (ReconstructSecretsGenNoise(client_list)) {
|
||||||
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Success,the secret is reconstructing.",
|
||||||
|
cur_iterator, next_req_time);
|
||||||
|
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success, reconstruct ok.";
|
||||||
|
} else {
|
||||||
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "the secret restructs failed.", cur_iterator,
|
||||||
|
next_req_time);
|
||||||
|
MS_LOG(ERROR) << "CipherReconStruct::ReconstructSecrets" << fl_id << " failed.";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED, "Clients' number is full.", cur_iterator,
|
||||||
|
next_req_time);
|
||||||
|
MS_LOG(INFO) << "CipherReconStruct::ReconstructSecrets" << fl_id << " Success : no need reconstruct.";
|
||||||
|
}
|
||||||
|
clock_t end_time = clock();
|
||||||
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
|
MS_LOG(INFO) << "Reconstruct get + gennoise data time is : " << duration;
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,
|
bool CipherReconStruct::GetNoiseMasksSum(std::vector<float> *result,
|
||||||
const std::map<std::string, std::vector<float>> &client_keys) {
|
const std::map<std::string, std::vector<float>> &client_noise) {
|
||||||
std::vector<float> sum(cipher_init_->featuremap_, 0.0);
|
std::vector<float> sum(cipher_init_->featuremap_, 0.0);
|
||||||
for (auto iter = client_keys.begin(); iter != client_keys.end(); iter++) {
|
for (auto iter = client_noise.begin(); iter != client_noise.end(); iter++) {
|
||||||
if (iter->second.size() != cipher_init_->featuremap_) {
|
if (iter->second.size() != cipher_init_->featuremap_) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -301,35 +336,52 @@ void CipherReconStruct::BuildReconstructSecretsRsp(const std::shared_ptr<fl::ser
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherReconStruct::GetSuvNoise(
|
bool CipherReconStruct::GetSuvNoise(const std::vector<std::string> &clients_share_list,
|
||||||
const std::vector<std::string> &clients_share_list,
|
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
|
||||||
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys, const string &fl_id,
|
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs,
|
||||||
std::vector<float> *noise, uint8_t *secret, int length) {
|
const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length) {
|
||||||
for (auto p_key = clients_share_list.begin(); p_key != clients_share_list.end(); ++p_key) {
|
for (auto p_key = clients_share_list.begin(); p_key != clients_share_list.end(); ++p_key) {
|
||||||
if (*p_key != fl_id) {
|
if (*p_key != fl_id) {
|
||||||
PrivateKey *privKey1 = KeyAgreement::FromPrivateBytes((unsigned char *)secret, length);
|
PrivateKey *privKey = KeyAgreement::FromPrivateBytes(secret, length);
|
||||||
if (privKey1 == NULL) {
|
if (privKey == NULL) {
|
||||||
MS_LOG(ERROR) << "create privKey1 failed\n";
|
MS_LOG(ERROR) << "create privKey failed\n";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::vector<unsigned char> public_key = record_public_keys.at(*p_key)[1];
|
std::vector<uint8_t> public_key = record_public_keys.at(*p_key)[1];
|
||||||
PublicKey *pubKey1 = KeyAgreement::FromPublicBytes(public_key.data(), public_key.size());
|
std::string iv_fl_id;
|
||||||
if (pubKey1 == NULL) {
|
if (fl_id < *p_key) {
|
||||||
MS_LOG(ERROR) << "create pubKey1 failed\n";
|
iv_fl_id = fl_id;
|
||||||
|
} else {
|
||||||
|
iv_fl_id = *p_key;
|
||||||
|
}
|
||||||
|
auto iter = client_ivs.find(iv_fl_id);
|
||||||
|
if (iter == client_ivs.end()) {
|
||||||
|
MS_LOG(ERROR) << "cannot get ivs for client: " << iv_fl_id;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "fl_id : " << fl_id << "other id : " << *p_key;
|
if (iter->second.size() != IV_NUM) {
|
||||||
unsigned char secret1[SECRET_MAX_LEN] = {0};
|
MS_LOG(ERROR) << "get " << iter->second.size() << " ivs, the iv num required is: " << IV_NUM;
|
||||||
unsigned char salt[SECRET_MAX_LEN] = {0};
|
return false;
|
||||||
if (KeyAgreement::ComputeSharedKey(privKey1, pubKey1, SECRET_MAX_LEN, salt, SECRET_MAX_LEN, secret1) < 0) {
|
}
|
||||||
|
std::vector<uint8_t> pw_iv = iter->second[PW_IV_INDEX];
|
||||||
|
std::vector<uint8_t> pw_salt = iter->second[PW_SALT_INDEX];
|
||||||
|
PublicKey *pubKey = KeyAgreement::FromPublicBytes(public_key.data(), public_key.size());
|
||||||
|
if (pubKey == NULL) {
|
||||||
|
MS_LOG(ERROR) << "create pubKey failed\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "private_key fl_id : " << fl_id << " public_key fl_id : " << *p_key;
|
||||||
|
uint8_t secret1[SECRET_MAX_LEN] = {0};
|
||||||
|
if (KeyAgreement::ComputeSharedKey(privKey, pubKey, SECRET_MAX_LEN, pw_salt.data(), pw_salt.size(), secret1) <
|
||||||
|
0) {
|
||||||
MS_LOG(ERROR) << "ComputeSharedKey failed\n";
|
MS_LOG(ERROR) << "ComputeSharedKey failed\n";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> noise_tmp;
|
std::vector<float> noise_tmp;
|
||||||
if (Random::RandomAESCTR(&noise_tmp, cipher_init_->featuremap_, (const unsigned char *)secret1, SECRET_MAX_LEN) <
|
if (Masking::GetMasking(&noise_tmp, cipher_init_->featuremap_, (const uint8_t *)secret1, SECRET_MAX_LEN,
|
||||||
0) {
|
pw_iv.data(), pw_iv.size()) < 0) {
|
||||||
MS_LOG(ERROR) << "RandomAESCTR failed\n";
|
MS_LOG(ERROR) << "Get Masking failed\n";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool symbol_noise = GetSymbol(fl_id, *p_key);
|
bool symbol_noise = GetSymbol(fl_id, *p_key);
|
||||||
|
@ -345,9 +397,6 @@ bool CipherReconStruct::GetSuvNoise(
|
||||||
noise->at(index) += noise_tmp[index];
|
noise->at(index) += noise_tmp[index];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int i = 0; i < 5; i++) {
|
|
||||||
MS_LOG(INFO) << "index " << i << " : " << noise_tmp[i];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -361,37 +410,41 @@ bool CipherReconStruct::GetSymbol(const std::string &str1, const std::string &st
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
|
// recombined shares by their source fl_id (ownners)
|
||||||
|
bool CipherReconStruct::ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
|
||||||
std::map<std::string, std::vector<clientshare_str>> *des) {
|
std::map<std::string, std::vector<clientshare_str>> *des) {
|
||||||
for (auto iter_ori = src.begin(); iter_ori != src.end(); ++iter_ori) {
|
if (des == nullptr) return false;
|
||||||
std::string fl_des = iter_ori->first;
|
for (auto iter = src.begin(); iter != src.end(); ++iter) {
|
||||||
auto &cur_clientshare_str = iter_ori->second;
|
std::string des_id = iter->first;
|
||||||
|
auto &cur_clientshare_str = iter->second;
|
||||||
for (size_t index_clientshare = 0; index_clientshare < cur_clientshare_str.size(); ++index_clientshare) {
|
for (size_t index_clientshare = 0; index_clientshare < cur_clientshare_str.size(); ++index_clientshare) {
|
||||||
std::string fl_src = cur_clientshare_str[index_clientshare].fl_id;
|
std::string src_id = cur_clientshare_str[index_clientshare].fl_id;
|
||||||
clientshare_str value;
|
clientshare_str value;
|
||||||
value.fl_id = fl_des;
|
value.fl_id = des_id;
|
||||||
value.share = cur_clientshare_str[index_clientshare].share;
|
value.share = cur_clientshare_str[index_clientshare].share;
|
||||||
value.index = cur_clientshare_str[index_clientshare].index;
|
value.index = cur_clientshare_str[index_clientshare].index;
|
||||||
if (des->find(fl_src) == des->end()) { // fl_id_des is not in reconstruct_secret_list_
|
if (des->find(src_id) == des->end()) { // src_id is not in recombined shares list
|
||||||
std::vector<clientshare_str> value_list;
|
std::vector<clientshare_str> value_list;
|
||||||
value_list.push_back(value);
|
value_list.push_back(value);
|
||||||
des->insert(std::pair<std::string, std::vector<clientshare_str>>(fl_src, value_list));
|
des->insert(std::pair<std::string, std::vector<clientshare_str>>(src_id, value_list));
|
||||||
} else { // fl_id_des is in reconstruct_secret_list_
|
} else {
|
||||||
des->at(fl_src).push_back(value);
|
des->at(src_id).push_back(value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int shares_size) {
|
bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int shares_size) {
|
||||||
|
if (shares_tmp == nullptr) return false;
|
||||||
for (int i = 0; i < shares_size; ++i) {
|
for (int i = 0; i < shares_size; ++i) {
|
||||||
Share *share_i = new Share;
|
Share *share_i = new Share();
|
||||||
if (share_i == nullptr) {
|
if (share_i == nullptr) {
|
||||||
MS_LOG(ERROR) << "shares_tmp " << i << " memory to cipher is invalid.";
|
MS_LOG(ERROR) << "shares_tmp " << i << " memory to cipher is invalid.";
|
||||||
DeleteShares(shares_tmp);
|
DeleteShares(shares_tmp);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
share_i->data = new unsigned char[SHARE_MAX_SIZE];
|
share_i->data = new uint8_t[SHARE_MAX_SIZE];
|
||||||
if (share_i->data == nullptr) {
|
if (share_i->data == nullptr) {
|
||||||
MS_LOG(ERROR) << "shares_tmp's data " << i << " memory to cipher is invalid.";
|
MS_LOG(ERROR) << "shares_tmp's data " << i << " memory to cipher is invalid.";
|
||||||
DeleteShares(shares_tmp);
|
DeleteShares(shares_tmp);
|
||||||
|
@ -405,6 +458,7 @@ bool CipherReconStruct::MallocShares(std::vector<Share *> *shares_tmp, int share
|
||||||
}
|
}
|
||||||
|
|
||||||
void CipherReconStruct::DeleteShares(std::vector<Share *> *shares_tmp) {
|
void CipherReconStruct::DeleteShares(std::vector<Share *> *shares_tmp) {
|
||||||
|
if (shares_tmp == nullptr) return;
|
||||||
if (shares_tmp->size() != 0) {
|
if (shares_tmp->size() != 0) {
|
||||||
for (size_t i = 0; i < shares_tmp->size(); ++i) {
|
for (size_t i = 0; i < shares_tmp->size(); ++i) {
|
||||||
if (shares_tmp->at(i) != nullptr && shares_tmp->at(i)->data != nullptr) {
|
if (shares_tmp->at(i) != nullptr && shares_tmp->at(i)->data != nullptr) {
|
||||||
|
|
|
@ -28,6 +28,8 @@
|
||||||
#include "fl/armour/cipher/cipher_init.h"
|
#include "fl/armour/cipher/cipher_init.h"
|
||||||
#include "fl/armour/cipher/cipher_meta_storage.h"
|
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||||
|
|
||||||
|
#define IV_NUM 3
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
// The process of reconstruct secret mask in the secure aggregation
|
// The process of reconstruct secret mask in the secure aggregation
|
||||||
|
@ -44,7 +46,7 @@ class CipherReconStruct {
|
||||||
// reconstruct secret mask
|
// reconstruct secret mask
|
||||||
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
|
bool ReconstructSecrets(const int cur_iterator, const std::string &next_req_time,
|
||||||
const schema::SendReconstructSecret *reconstruct_secret_req,
|
const schema::SendReconstructSecret *reconstruct_secret_req,
|
||||||
const std::shared_ptr<fl::server::FBBuilder> &reconstruct_secret_resp_builder,
|
const std::shared_ptr<fl::server::FBBuilder> &fbb,
|
||||||
const std::vector<std::string> &client_list);
|
const std::vector<std::string> &client_list);
|
||||||
|
|
||||||
// build response code of reconstruct secret.
|
// build response code of reconstruct secret.
|
||||||
|
@ -60,26 +62,28 @@ class CipherReconStruct {
|
||||||
bool GetSymbol(const std::string &str1, const std::string &str2);
|
bool GetSymbol(const std::string &str1, const std::string &str2);
|
||||||
// get suv noise by computing shares result.
|
// get suv noise by computing shares result.
|
||||||
bool GetSuvNoise(const std::vector<std::string> &clients_share_list,
|
bool GetSuvNoise(const std::vector<std::string> &clients_share_list,
|
||||||
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
|
const std::map<std::string, std::vector<std::vector<uint8_t>>> &record_public_keys,
|
||||||
const string &fl_id, std::vector<float> *noise, uint8_t *secret, int length);
|
const std::map<std::string, std::vector<std::vector<uint8_t>>> &client_ivs, const string &fl_id,
|
||||||
|
std::vector<float> *noise, uint8_t *secret, int length);
|
||||||
// malloc shares.
|
// malloc shares.
|
||||||
bool MallocShares(std::vector<Share *> *shares_tmp, int shares_size);
|
bool MallocShares(std::vector<Share *> *shares_tmp, int shares_size);
|
||||||
// delete shares.
|
// delete shares.
|
||||||
void DeleteShares(std::vector<Share *> *shares_tmp);
|
void DeleteShares(std::vector<Share *> *shares_tmp);
|
||||||
// convert shares from receiving clients to sending clients.
|
// convert shares from receiving clients to sending clients.
|
||||||
void ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
|
bool ConvertSharesToShares(const std::map<std::string, std::vector<clientshare_str>> &src,
|
||||||
std::map<std::string, std::vector<clientshare_str>> *des);
|
std::map<std::string, std::vector<clientshare_str>> *des);
|
||||||
// generate noise from shares.
|
// generate noise from shares.
|
||||||
bool ReconstructSecretsGenNoise(const std::vector<string> &client_list);
|
bool ReconstructSecretsGenNoise(const std::vector<string> &client_list);
|
||||||
// get noise masks sum.
|
// get noise masks sum.
|
||||||
bool GetNoiseMasksSum(std::vector<float> *result, const std::map<std::string, std::vector<float>> &client_keys);
|
bool GetNoiseMasksSum(std::vector<float> *result, const std::map<std::string, std::vector<float>> &client_noise);
|
||||||
|
|
||||||
// combine noise mask.
|
// combine noise mask.
|
||||||
bool CombineMask(std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_keys,
|
bool CombineMask(std::vector<Share *> *shares_tmp, std::map<std::string, std::vector<float>> *client_noise,
|
||||||
const std::vector<std::string> &clients_share_list,
|
const std::vector<std::string> &clients_share_list,
|
||||||
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
|
const std::map<std::string, std::vector<std::vector<unsigned char>>> &record_public_keys,
|
||||||
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
|
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
|
||||||
const std::vector<string> &client_list);
|
const std::vector<string> &client_list,
|
||||||
|
const std::map<std::string, std::vector<std::vector<unsigned char>>> &client_ivs);
|
||||||
};
|
};
|
||||||
} // namespace armour
|
} // namespace armour
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -25,8 +25,8 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
|
||||||
const string next_req_time) {
|
const string next_req_time) {
|
||||||
MS_LOG(INFO) << "CipherShares::ShareSecrets START";
|
MS_LOG(INFO) << "CipherShares::ShareSecrets START";
|
||||||
if (share_secrets_req == nullptr) {
|
if (share_secrets_req == nullptr) {
|
||||||
MS_LOG(ERROR) << "Request is nullptr or Response builder is nullptr.";
|
std::string reason = "Request is nullptr";
|
||||||
std::string reason = "Request is nullptr or Response builder is nullptr.";
|
MS_LOG(ERROR) << reason;
|
||||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time,
|
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError, reason, next_req_time,
|
||||||
cur_iterator);
|
cur_iterator);
|
||||||
return false;
|
return false;
|
||||||
|
@ -34,38 +34,34 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
|
||||||
|
|
||||||
// step 1: get client list and share secrets from memory server.
|
// step 1: get client list and share secrets from memory server.
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
|
|
||||||
|
int iteration = share_secrets_req->iteration();
|
||||||
|
std::vector<std::string> get_keys_clients;
|
||||||
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetKeysClientList, &get_keys_clients);
|
||||||
std::vector<std::string> clients_share_list;
|
std::vector<std::string> clients_share_list;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
||||||
&clients_share_list);
|
&clients_share_list);
|
||||||
std::vector<std::string> clients_exchange_list;
|
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList,
|
|
||||||
&clients_exchange_list);
|
|
||||||
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
|
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
|
||||||
&encrypted_shares_all);
|
&encrypted_shares_all);
|
||||||
|
|
||||||
MS_LOG(INFO) << "Client of keys size : " << clients_exchange_list.size()
|
MS_LOG(INFO) << "Client of get keys size : " << get_keys_clients.size()
|
||||||
<< "client of shares size : " << clients_share_list.size() << "shares size"
|
<< "client of update shares size : " << clients_share_list.size()
|
||||||
<< encrypted_shares_all.size();
|
<< "updated shares size: " << encrypted_shares_all.size();
|
||||||
if (encrypted_shares_all.size() != clients_share_list.size()) {
|
|
||||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime,
|
|
||||||
"client of shares and shares size are not equal", next_req_time, cur_iterator);
|
|
||||||
MS_LOG(ERROR) << "client of shares and shares size are not equal. client of shares size : "
|
|
||||||
<< clients_share_list.size() << "shares size" << encrypted_shares_all.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
// step 2: update new item to memory server. serialise: update pb struct to memory server.
|
// step 2: update new item to memory server. serialise: update pb struct to memory server.
|
||||||
int iteration = share_secrets_req->iteration();
|
|
||||||
std::string fl_id_src = share_secrets_req->fl_id()->str();
|
std::string fl_id_src = share_secrets_req->fl_id()->str();
|
||||||
if (find(clients_exchange_list.begin(), clients_exchange_list.end(), fl_id_src) ==
|
if (find(get_keys_clients.begin(), get_keys_clients.end(), fl_id_src) == get_keys_clients.end()) {
|
||||||
clients_exchange_list.end()) { // the client not in clients_exchange_list, return false.
|
// the client not in get keys clients
|
||||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError,
|
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_RequestError,
|
||||||
("client share secret is not in clients_exchange list. && client is illegal"), next_req_time,
|
("client share secret is not in getkeys list. && client is illegal"), next_req_time,
|
||||||
iteration);
|
iteration);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (find(clients_share_list.begin(), clients_share_list.end(), fl_id_src) !=
|
|
||||||
clients_share_list.end()) { // the client is already exists, return false.
|
if (encrypted_shares_all.find(fl_id_src) != encrypted_shares_all.end()) { // the client is already exists
|
||||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED,
|
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED,
|
||||||
("client sharesecret already exists."), next_req_time, iteration);
|
("client sharesecret already exists."), next_req_time, iteration);
|
||||||
return false;
|
return false;
|
||||||
|
@ -74,24 +70,22 @@ bool CipherShares::ShareSecrets(const int cur_iterator, const schema::RequestSha
|
||||||
// update new item to memory server.
|
// update new item to memory server.
|
||||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
|
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::ClientShare>> *encrypted_shares =
|
||||||
(share_secrets_req->encrypted_shares());
|
(share_secrets_req->encrypted_shares());
|
||||||
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
|
||||||
fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
|
|
||||||
bool retcode_client =
|
bool retcode_client =
|
||||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src);
|
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxShareSecretsClientList, fl_id_src);
|
||||||
bool retcode = retcode_share && retcode_client;
|
bool retcode_share = cipher_init_->cipher_meta_storage_.UpdateClientShareToServer(
|
||||||
if (retcode) {
|
fl::server::kCtxClientsEncryptedShares, fl_id_src, encrypted_shares);
|
||||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
|
if (!(retcode_share && retcode_client)) {
|
||||||
MS_LOG(INFO) << "CipherShares::ShareSecrets Success";
|
|
||||||
} else {
|
|
||||||
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime,
|
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_OutOfTime,
|
||||||
"update client of shares and shares failed", next_req_time, iteration);
|
"update client of shares and shares failed", next_req_time, iteration);
|
||||||
MS_LOG(ERROR) << "CipherShares::ShareSecrets update client of shares and shares failed ";
|
MS_LOG(ERROR) << "CipherShares::ShareSecrets update client of shares and shares failed ";
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
BuildShareSecretsRsp(share_secrets_resp_builder, schema::ResponseCode_SUCCEED, "OK", next_req_time, iteration);
|
||||||
|
MS_LOG(INFO) << "CipherShares::ShareSecrets Success";
|
||||||
clock_t end_time = clock();
|
clock_t end_time = clock();
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
MS_LOG(INFO) << "ShareSecrets get + deal + update data time is : " << duration;
|
MS_LOG(INFO) << "ShareSecrets get + deal + update data time is : " << duration;
|
||||||
return retcode;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
||||||
|
@ -101,36 +95,38 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
// step 0: check whether the parameters are legal.
|
// step 0: check whether the parameters are legal.
|
||||||
if (get_secrets_req == nullptr) {
|
if (get_secrets_req == nullptr) {
|
||||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, 0);
|
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SystemError, 0, next_req_time, nullptr);
|
||||||
MS_LOG(ERROR) << "GetSecrets: get_secrets_req is nullptr.";
|
MS_LOG(ERROR) << "GetSecrets: get_secrets_req is nullptr.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// step 1: get client list and client shares list from memory server.
|
// step 1: get client list and client shares list from memory server.
|
||||||
std::vector<std::string> clients_share_list;
|
|
||||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxShareSecretsClientList,
|
|
||||||
&clients_share_list);
|
|
||||||
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
std::map<std::string, std::vector<clientshare_str>> encrypted_shares_all;
|
||||||
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
|
cipher_init_->cipher_meta_storage_.GetClientSharesFromServer(fl::server::kCtxClientsEncryptedShares,
|
||||||
&encrypted_shares_all);
|
&encrypted_shares_all);
|
||||||
int iteration = get_secrets_req->iteration();
|
int iteration = get_secrets_req->iteration();
|
||||||
size_t share_clients_num = clients_share_list.size();
|
size_t encrypted_shares_num = encrypted_shares_all.size();
|
||||||
size_t cients_has_shares = encrypted_shares_all.size();
|
if (cipher_init_->share_secrets_threshold > encrypted_shares_num) { // the client num is not enough, return false.
|
||||||
if (share_clients_num != cients_has_shares) {
|
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
|
||||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_OutOfTime, iteration, next_req_time, 0);
|
MS_LOG(INFO) << "GetSecrets: the encrypted shares num is not enough: share_secrets_threshold: "
|
||||||
MS_LOG(ERROR) << "cients_has_shares: " << cients_has_shares << "share_clients_num: " << share_clients_num;
|
<< cipher_init_->share_secrets_threshold << "encrypted_shares_num: " << encrypted_shares_num;
|
||||||
}
|
|
||||||
if (cipher_init_->share_clients_num_need_ > share_clients_num) { // the client num is not enough, return false.
|
|
||||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, 0);
|
|
||||||
MS_LOG(INFO) << "GetSecrets: the client num is not enough: share_clients_num_need_: "
|
|
||||||
<< cipher_init_->share_clients_num_need_ << "share_clients_num: " << share_clients_num;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string fl_id = get_secrets_req->fl_id()->str();
|
std::string fl_id = get_secrets_req->fl_id()->str();
|
||||||
if (find(clients_share_list.begin(), clients_share_list.end(), fl_id) ==
|
// the client is not in share secrets client list.
|
||||||
clients_share_list.end()) { // the client is not in client list, return false.
|
if (encrypted_shares_all.find(fl_id) == encrypted_shares_all.end()) {
|
||||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, 0);
|
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_RequestError, iteration, next_req_time, nullptr);
|
||||||
MS_LOG(ERROR) << "GetSecrets: client is not in client list.";
|
MS_LOG(ERROR) << "GetSecrets: client is not in share secrets client list.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool retcode_client =
|
||||||
|
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetSecretsClientList, fl_id);
|
||||||
|
if (!retcode_client) {
|
||||||
|
MS_LOG(ERROR) << "update get secrets clients failed";
|
||||||
|
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SucNotReady, iteration, next_req_time, nullptr);
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the result client shares.
|
// get the result client shares.
|
||||||
|
@ -171,7 +167,6 @@ bool CipherShares::GetSecrets(const schema::GetShareSecrets *get_secrets_req,
|
||||||
|
|
||||||
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SUCCEED, iteration, next_req_time,
|
BuildGetSecretsRsp(get_secrets_resp_builder, schema::ResponseCode_SUCCEED, iteration, next_req_time,
|
||||||
&encrypted_shares);
|
&encrypted_shares);
|
||||||
|
|
||||||
MS_LOG(INFO) << "CipherShares::GetSecrets Success";
|
MS_LOG(INFO) << "CipherShares::GetSecrets Success";
|
||||||
clock_t end_time = clock();
|
clock_t end_time = clock();
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
|
@ -185,7 +180,7 @@ void CipherShares::BuildGetSecretsRsp(
|
||||||
int rsp_retcode = retcode;
|
int rsp_retcode = retcode;
|
||||||
int rsp_iteration = iteration;
|
int rsp_iteration = iteration;
|
||||||
auto rsp_next_req_time = get_secrets_resp_builder->CreateString(next_req_time);
|
auto rsp_next_req_time = get_secrets_resp_builder->CreateString(next_req_time);
|
||||||
if (encrypted_shares == 0) {
|
if (encrypted_shares == nullptr) {
|
||||||
auto get_secrets_rsp =
|
auto get_secrets_rsp =
|
||||||
schema::CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration, 0, rsp_next_req_time);
|
schema::CreateReturnShareSecrets(*get_secrets_resp_builder, rsp_retcode, rsp_iteration, 0, rsp_next_req_time);
|
||||||
get_secrets_resp_builder->Finish(get_secrets_rsp);
|
get_secrets_resp_builder->Finish(get_secrets_rsp);
|
||||||
|
@ -195,7 +190,6 @@ void CipherShares::BuildGetSecretsRsp(
|
||||||
encrypted_shares_rsp, rsp_next_req_time);
|
encrypted_shares_rsp, rsp_next_req_time);
|
||||||
get_secrets_resp_builder->Finish(get_secrets_rsp);
|
get_secrets_resp_builder->Finish(get_secrets_rsp);
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,8 +26,8 @@ bool CipherUnmask::UnMask(const std::map<std::string, AddressPtr> &data) {
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
std::vector<float> noise;
|
std::vector<float> noise;
|
||||||
|
|
||||||
(void)cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
|
bool ret = cipher_init_->cipher_meta_storage_.GetClientNoisesFromServer(fl::server::kCtxClientNoises, &noise);
|
||||||
if (noise.size() != cipher_init_->featuremap_) {
|
if (!ret || noise.size() != cipher_init_->featuremap_) {
|
||||||
MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
|
MS_LOG(ERROR) << " CipherMgr UnMask ERROR";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,12 +18,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
|
AESEncrypt::AESEncrypt(const uint8_t *key, int key_len, const uint8_t *ivec, int ivec_len, const AES_MODE mode) {
|
||||||
#define KEY_STEP_MAX 32
|
|
||||||
#define KEY_STEP_MIN 16
|
|
||||||
#define PAD_SIZE 5
|
|
||||||
|
|
||||||
AESEncrypt::AESEncrypt(const unsigned char *key, int key_len, unsigned char *ivec, int ivec_len, const AES_MODE mode) {
|
|
||||||
privKey = key;
|
privKey = key;
|
||||||
privKeyLen = key_len;
|
privKeyLen = key_len;
|
||||||
iVec = ivec;
|
iVec = ivec;
|
||||||
|
@ -47,12 +42,20 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
|
||||||
#else
|
#else
|
||||||
int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) {
|
int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len) {
|
||||||
int ret;
|
int ret;
|
||||||
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) {
|
if (privKey == NULL || iVec == NULL) {
|
||||||
MS_LOG(ERROR) << "key length must be 16 or 32!";
|
MS_LOG(ERROR) << "private key or init vector is invalid.";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
if (iVecLen != INIT_VEC_SIZE) {
|
if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
|
||||||
MS_LOG(ERROR) << "initial vector size must be 16!";
|
MS_LOG(ERROR) << "key length is invalid.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (iVecLen != AES_IV_SIZE) {
|
||||||
|
MS_LOG(ERROR) << "initial vector size is invalid.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (data == NULL || len <= 0 || encrypt_data == NULL) {
|
||||||
|
MS_LOG(ERROR) << "input data is invalid.";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
if (aesMode == AES_CBC || aesMode == AES_CTR) {
|
if (aesMode == AES_CBC || aesMode == AES_CTR) {
|
||||||
|
@ -69,12 +72,20 @@ int AESEncrypt::EncryptData(const unsigned char *data, const int len, unsigned c
|
||||||
|
|
||||||
int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) {
|
int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len) {
|
||||||
int ret = 0;
|
int ret = 0;
|
||||||
if (privKeyLen != KEY_STEP_MIN && privKeyLen != KEY_STEP_MAX) {
|
if (privKey == NULL || iVec == NULL) {
|
||||||
MS_LOG(ERROR) << "key length must be 16 or 32!";
|
MS_LOG(ERROR) << "private key or init vector is invalid.";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
if (iVecLen != INIT_VEC_SIZE) {
|
if (privKeyLen != KEY_LENGTH_16 && privKeyLen != KEY_LENGTH_32) {
|
||||||
MS_LOG(ERROR) << "initial vector size must be 16!";
|
MS_LOG(ERROR) << "key length is invalid.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (iVecLen != AES_IV_SIZE) {
|
||||||
|
MS_LOG(ERROR) << "initial vector size is invalid.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (data == NULL || encrypt_len <= 0 || encrypt_data == NULL) {
|
||||||
|
MS_LOG(ERROR) << "input data is invalid.";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
if (aesMode == AES_CBC || aesMode == AES_CTR) {
|
if (aesMode == AES_CBC || aesMode == AES_CTR) {
|
||||||
|
@ -88,17 +99,21 @@ int AESEncrypt::DecryptData(const unsigned char *encrypt_data, const int encrypt
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const unsigned char *key, unsigned char *ivec,
|
int AESEncrypt::evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
|
||||||
unsigned char *encrypt_data, int *encrypt_len) {
|
uint8_t *encrypt_data, int *encrypt_len) {
|
||||||
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
||||||
|
if (ctx == NULL) {
|
||||||
|
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
int out_len;
|
int out_len;
|
||||||
int ret = 0;
|
int ret;
|
||||||
if (aesMode == AES_CBC) {
|
if (aesMode == AES_CBC) {
|
||||||
switch (privKeyLen) {
|
switch (privKeyLen) {
|
||||||
case 16:
|
case KEY_LENGTH_16:
|
||||||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
|
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
|
||||||
break;
|
break;
|
||||||
case 32:
|
case KEY_LENGTH_32:
|
||||||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
|
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -107,16 +122,16 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
|
||||||
}
|
}
|
||||||
if (ret != 1) {
|
if (ret != 1) {
|
||||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex CBC fail!";
|
MS_LOG(ERROR) << "EVP_EncryptInit_ex CBC fail!";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
EVP_CIPHER_CTX_set_key_length(ctx, EVP_MAX_KEY_LENGTH);
|
EVP_CIPHER_CTX_set_padding(ctx, EVP_PADDING_PKCS7);
|
||||||
EVP_CIPHER_CTX_set_padding(ctx, PAD_SIZE);
|
|
||||||
} else if (aesMode == AES_CTR) {
|
} else if (aesMode == AES_CTR) {
|
||||||
switch (privKeyLen) {
|
switch (privKeyLen) {
|
||||||
case 16:
|
case KEY_LENGTH_16:
|
||||||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
|
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
|
||||||
break;
|
break;
|
||||||
case 32:
|
case KEY_LENGTH_32:
|
||||||
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
|
ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -125,21 +140,25 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
|
||||||
}
|
}
|
||||||
if (ret != 1) {
|
if (ret != 1) {
|
||||||
MS_LOG(ERROR) << "EVP_EncryptInit_ex CTR fail!";
|
MS_LOG(ERROR) << "EVP_EncryptInit_ex CTR fail!";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(ERROR) << "Unsupported AES mode";
|
MS_LOG(ERROR) << "Unsupported AES mode";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
ret = EVP_EncryptUpdate(ctx, encrypt_data, &out_len, data, len);
|
ret = EVP_EncryptUpdate(ctx, encrypt_data, &out_len, data, len);
|
||||||
if (ret != 1) {
|
if (ret != 1) {
|
||||||
MS_LOG(ERROR) << "EVP_EncryptUpdate fail!";
|
MS_LOG(ERROR) << "EVP_EncryptUpdate fail!";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
*encrypt_len = out_len;
|
*encrypt_len = out_len;
|
||||||
ret = EVP_EncryptFinal_ex(ctx, encrypt_data + *encrypt_len, &out_len);
|
ret = EVP_EncryptFinal_ex(ctx, encrypt_data + *encrypt_len, &out_len);
|
||||||
if (ret != 1) {
|
if (ret != 1) {
|
||||||
MS_LOG(ERROR) << "EVP_EncryptFinal_ex fail!";
|
MS_LOG(ERROR) << "EVP_EncryptFinal_ex fail!";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
*encrypt_len += out_len;
|
*encrypt_len += out_len;
|
||||||
|
@ -147,17 +166,21 @@ int AESEncrypt::evp_aes_encrypt(const unsigned char *data, const int len, const
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len, const unsigned char *key,
|
int AESEncrypt::evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
|
||||||
unsigned char *ivec, unsigned char *decrypt_data, int *decrypt_len) {
|
uint8_t *decrypt_data, int *decrypt_len) {
|
||||||
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
|
||||||
|
if (ctx == NULL) {
|
||||||
|
MS_LOG(ERROR) << "EVP_CIPHER_CTX_new fail!";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
int out_len;
|
int out_len;
|
||||||
int ret = 0;
|
int ret;
|
||||||
if (aesMode == AES_CBC) {
|
if (aesMode == AES_CBC) {
|
||||||
switch (privKeyLen) {
|
switch (privKeyLen) {
|
||||||
case 16:
|
case KEY_LENGTH_16:
|
||||||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
|
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_cbc(), NULL, key, ivec);
|
||||||
break;
|
break;
|
||||||
case 32:
|
case KEY_LENGTH_32:
|
||||||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
|
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, ivec);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -165,40 +188,46 @@ int AESEncrypt::evp_aes_decrypt(const unsigned char *encrypt_data, const int len
|
||||||
ret = -1;
|
ret = -1;
|
||||||
}
|
}
|
||||||
if (ret != 1) {
|
if (ret != 1) {
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
EVP_CIPHER_CTX_set_key_length(ctx, EVP_MAX_KEY_LENGTH);
|
|
||||||
} else if (aesMode == AES_CTR) {
|
} else if (aesMode == AES_CTR) {
|
||||||
switch (privKeyLen) {
|
switch (privKeyLen) {
|
||||||
case 16:
|
case KEY_LENGTH_16:
|
||||||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
|
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_ctr(), NULL, key, ivec);
|
||||||
break;
|
break;
|
||||||
case 32:
|
case KEY_LENGTH_32:
|
||||||
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
|
ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_ctr(), NULL, key, ivec);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(ERROR) << "key length is incorrect!";
|
MS_LOG(ERROR) << "key length is incorrect!";
|
||||||
ret = -1;
|
ret = -1;
|
||||||
}
|
}
|
||||||
|
if (ret != 1) {
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
ret = -1;
|
MS_LOG(ERROR) << "Unsupported AES mode";
|
||||||
}
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
|
||||||
if (ret != 1) {
|
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = EVP_DecryptUpdate(ctx, decrypt_data, &out_len, encrypt_data, len);
|
ret = EVP_DecryptUpdate(ctx, decrypt_data, &out_len, encrypt_data, len);
|
||||||
if (ret != 1) {
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptUpdate fail!";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
*decrypt_len = out_len;
|
*decrypt_len = out_len;
|
||||||
ret = EVP_DecryptFinal_ex(ctx, decrypt_data + *decrypt_len, &out_len);
|
ret = EVP_DecryptFinal_ex(ctx, decrypt_data + *decrypt_len, &out_len);
|
||||||
if (ret != 1) {
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "EVP_DecryptFinal_ex fail!";
|
||||||
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
*decrypt_len += out_len;
|
*decrypt_len += out_len;
|
||||||
EVP_CIPHER_CTX_free(ctx);
|
EVP_CIPHER_CTX_free(ctx);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -22,7 +22,9 @@
|
||||||
#endif
|
#endif
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
#define INIT_VEC_SIZE 16
|
#define AES_IV_SIZE 16
|
||||||
|
#define KEY_LENGTH_32 32
|
||||||
|
#define KEY_LENGTH_16 16
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
|
@ -35,21 +37,21 @@ class SymmetricEncrypt : Encrypt {};
|
||||||
|
|
||||||
class AESEncrypt : SymmetricEncrypt {
|
class AESEncrypt : SymmetricEncrypt {
|
||||||
public:
|
public:
|
||||||
AESEncrypt(const unsigned char *key, int key_len, unsigned char *ivec, int ivec_len, AES_MODE mode);
|
AESEncrypt(const unsigned char *key, int key_len, const uint8_t *ivec, int ivec_len, AES_MODE mode);
|
||||||
~AESEncrypt();
|
~AESEncrypt();
|
||||||
int EncryptData(const unsigned char *data, const int len, unsigned char *encrypt_data, int *encrypt_len);
|
int EncryptData(const uint8_t *data, const int len, uint8_t *encrypt_data, int *encrypt_len);
|
||||||
int DecryptData(const unsigned char *encrypt_data, const int encrypt_len, unsigned char *data, int *len);
|
int DecryptData(const uint8_t *encrypt_data, const int encrypt_len, uint8_t *data, int *len);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const unsigned char *privKey;
|
const uint8_t *privKey;
|
||||||
int privKeyLen;
|
int privKeyLen;
|
||||||
unsigned char *iVec;
|
const uint8_t *iVec;
|
||||||
int iVecLen;
|
int iVecLen;
|
||||||
AES_MODE aesMode;
|
AES_MODE aesMode;
|
||||||
int evp_aes_encrypt(const unsigned char *data, const int len, const unsigned char *key, unsigned char *ivec,
|
int evp_aes_encrypt(const uint8_t *data, const int len, const uint8_t *key, const uint8_t *ivec,
|
||||||
unsigned char *encrypt_data, int *encrypt_len);
|
uint8_t *encrypt_data, int *encrypt_len);
|
||||||
int evp_aes_decrypt(const unsigned char *encrypt_data, const int len, const unsigned char *key, unsigned char *ivec,
|
int evp_aes_decrypt(const uint8_t *encrypt_data, const int len, const uint8_t *key, const uint8_t *ivec,
|
||||||
unsigned char *decrypt_data, int *decrypt_len);
|
uint8_t *decrypt_data, int *decrypt_len);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace armour
|
} // namespace armour
|
||||||
|
|
|
@ -54,14 +54,22 @@ PrivateKey::PrivateKey(EVP_PKEY *evpKey) { evpPrivKey = evpKey; }
|
||||||
|
|
||||||
PrivateKey::~PrivateKey() { EVP_PKEY_free(evpPrivKey); }
|
PrivateKey::~PrivateKey() { EVP_PKEY_free(evpPrivKey); }
|
||||||
|
|
||||||
int PrivateKey::GetPrivateBytes(size_t *len, unsigned char *privKeyBytes) {
|
int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) {
|
||||||
|
if (privKeyBytes == nullptr || len <= 0) {
|
||||||
|
MS_LOG(ERROR) << "input privKeyBytes invalid.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
if (!EVP_PKEY_get_raw_private_key(evpPrivKey, privKeyBytes, len)) {
|
if (!EVP_PKEY_get_raw_private_key(evpPrivKey, privKeyBytes, len)) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int PrivateKey::GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) {
|
int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) {
|
||||||
|
if (pubKeyBytes == nullptr || len <= 0) {
|
||||||
|
MS_LOG(ERROR) << "input pubKeyBytes invalid.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
if (!EVP_PKEY_get_raw_public_key(evpPrivKey, pubKeyBytes, len)) {
|
if (!EVP_PKEY_get_raw_public_key(evpPrivKey, pubKeyBytes, len)) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
@ -69,7 +77,19 @@ int PrivateKey::GetPublicBytes(size_t *len, unsigned char *pubKeyBytes) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned char *salt, int salt_len,
|
int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned char *salt, int salt_len,
|
||||||
unsigned char *exchangeKey) {
|
uint8_t *exchangeKey) {
|
||||||
|
if (peerPublicKey == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "peerPublicKey is nullptr.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (key_len != KEY_LEN || exchangeKey == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "exchangeKey is nullptr or input key_len is incorrect.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (salt == nullptr || salt_len != SALT_LEN) {
|
||||||
|
MS_LOG(ERROR) << "input salt in invalid.";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
EVP_PKEY_CTX *ctx;
|
EVP_PKEY_CTX *ctx;
|
||||||
size_t len = 0;
|
size_t len = 0;
|
||||||
ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL);
|
ctx = EVP_PKEY_CTX_new(evpPrivKey, NULL);
|
||||||
|
@ -79,30 +99,38 @@ int PrivateKey::Exchange(PublicKey *peerPublicKey, int key_len, const unsigned c
|
||||||
}
|
}
|
||||||
if (EVP_PKEY_derive_init(ctx) <= 0) {
|
if (EVP_PKEY_derive_init(ctx) <= 0) {
|
||||||
MS_LOG(ERROR) << "EVP_PKEY_derive_init failed!";
|
MS_LOG(ERROR) << "EVP_PKEY_derive_init failed!";
|
||||||
|
EVP_PKEY_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
if (EVP_PKEY_derive_set_peer(ctx, peerPublicKey->evpPubKey) <= 0) {
|
if (EVP_PKEY_derive_set_peer(ctx, peerPublicKey->evpPubKey) <= 0) {
|
||||||
MS_LOG(ERROR) << "EVP_PKEY_derive_set_peer failed!";
|
MS_LOG(ERROR) << "EVP_PKEY_derive_set_peer failed!";
|
||||||
|
EVP_PKEY_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
unsigned char *secret;
|
unsigned char *secret;
|
||||||
if (EVP_PKEY_derive(ctx, NULL, &len) <= 0) {
|
if (EVP_PKEY_derive(ctx, NULL, &len) <= 0) {
|
||||||
MS_LOG(ERROR) << "get derive key size failed!";
|
MS_LOG(ERROR) << "get derive key size failed!";
|
||||||
|
EVP_PKEY_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
secret = (unsigned char *)OPENSSL_malloc(len);
|
secret = (unsigned char *)OPENSSL_malloc(len);
|
||||||
if (!secret) {
|
if (!secret) {
|
||||||
MS_LOG(ERROR) << "malloc secret memory failed!";
|
MS_LOG(ERROR) << "malloc secret memory failed!";
|
||||||
|
EVP_PKEY_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (EVP_PKEY_derive(ctx, secret, &len) <= 0) {
|
if (EVP_PKEY_derive(ctx, secret, &len) <= 0) {
|
||||||
MS_LOG(ERROR) << "derive key failed!";
|
MS_LOG(ERROR) << "derive key failed!";
|
||||||
|
OPENSSL_free(secret);
|
||||||
|
EVP_PKEY_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
if (!PKCS5_PBKDF2_HMAC(reinterpret_cast<char *>(secret), len, salt, salt_len, ITERATION, EVP_sha256(), key_len,
|
||||||
if (!PKCS5_PBKDF2_HMAC((char *)secret, len, salt, salt_len, ITERATION, EVP_sha256(), key_len, exchangeKey)) {
|
exchangeKey)) {
|
||||||
|
OPENSSL_free(secret);
|
||||||
|
EVP_PKEY_CTX_free(ctx);
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
OPENSSL_free(secret);
|
OPENSSL_free(secret);
|
||||||
|
@ -118,9 +146,11 @@ PrivateKey *KeyAgreement::GeneratePrivKey() {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
if (EVP_PKEY_keygen_init(pctx) <= 0) {
|
if (EVP_PKEY_keygen_init(pctx) <= 0) {
|
||||||
|
EVP_PKEY_CTX_free(pctx);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
if (EVP_PKEY_keygen(pctx, &evpKey) <= 0) {
|
if (EVP_PKEY_keygen(pctx, &evpKey) <= 0) {
|
||||||
|
EVP_PKEY_CTX_free(pctx);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
EVP_PKEY_CTX_free(pctx);
|
EVP_PKEY_CTX_free(pctx);
|
||||||
|
@ -131,14 +161,30 @@ PrivateKey *KeyAgreement::GeneratePrivKey() {
|
||||||
PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
|
PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
|
||||||
unsigned char *pubKeyBytes;
|
unsigned char *pubKeyBytes;
|
||||||
size_t len = 0;
|
size_t len = 0;
|
||||||
|
if (privKey == nullptr) {
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, NULL, &len)) {
|
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, NULL, &len)) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
pubKeyBytes = (unsigned char *)OPENSSL_malloc(len);
|
pubKeyBytes = reinterpret_cast<uint8_t *>(OPENSSL_malloc(len));
|
||||||
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, pubKeyBytes, &len)) {
|
if (!pubKeyBytes) {
|
||||||
|
MS_LOG(ERROR) << "malloc secret memory failed!";
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!EVP_PKEY_get_raw_public_key(privKey->evpPrivKey, pubKeyBytes, &len)) {
|
||||||
|
MS_LOG(ERROR) << "EVP_PKEY_get_raw_public_key failed!";
|
||||||
|
OPENSSL_free(pubKeyBytes);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
EVP_PKEY *evp_pubKey =
|
||||||
|
EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, reinterpret_cast<uint8_t *>(pubKeyBytes), len);
|
||||||
|
if (evp_pubKey == NULL) {
|
||||||
|
MS_LOG(ERROR) << "EVP_PKEY_new_raw_public_key failed!";
|
||||||
|
OPENSSL_free(pubKeyBytes);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
EVP_PKEY *evp_pubKey = EVP_PKEY_new_raw_public_key(EVP_PKEY_X25519, NULL, (unsigned char *)pubKeyBytes, len);
|
|
||||||
OPENSSL_free(pubKeyBytes);
|
OPENSSL_free(pubKeyBytes);
|
||||||
PublicKey *pubKey = new PublicKey(evp_pubKey);
|
PublicKey *pubKey = new PublicKey(evp_pubKey);
|
||||||
return pubKey;
|
return pubKey;
|
||||||
|
@ -147,6 +193,7 @@ PublicKey *KeyAgreement::GeneratePubKey(PrivateKey *privKey) {
|
||||||
PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) {
|
PrivateKey *KeyAgreement::FromPrivateBytes(unsigned char *data, int len) {
|
||||||
EVP_PKEY *evp_Key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, NULL, data, len);
|
EVP_PKEY *evp_Key = EVP_PKEY_new_raw_private_key(EVP_PKEY_X25519, NULL, data, len);
|
||||||
if (evp_Key == NULL) {
|
if (evp_Key == NULL) {
|
||||||
|
MS_LOG(ERROR) << "create evp_Key from raw bytes failed!";
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
PrivateKey *privKey = new PrivateKey(evp_Key);
|
PrivateKey *privKey = new PrivateKey(evp_Key);
|
||||||
|
@ -164,7 +211,11 @@ PublicKey *KeyAgreement::FromPublicBytes(unsigned char *data, int len) {
|
||||||
}
|
}
|
||||||
|
|
||||||
int KeyAgreement::ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey, int key_len,
|
int KeyAgreement::ComputeSharedKey(PrivateKey *privKey, PublicKey *peerPublicKey, int key_len,
|
||||||
const unsigned char *salt, int salt_len, unsigned char *exchangeKey) {
|
const unsigned char *salt, int salt_len, uint8_t *exchangeKey) {
|
||||||
|
if (privKey == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "privKey is nullptr!";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
return privKey->Exchange(peerPublicKey, key_len, salt, salt_len, exchangeKey);
|
return privKey->Exchange(peerPublicKey, key_len, salt, salt_len, exchangeKey);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -24,7 +24,8 @@
|
||||||
#endif
|
#endif
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
#define KEK_KEY_LEN 32
|
#define KEY_LEN 32
|
||||||
|
#define SALT_LEN 32
|
||||||
#define ITERATION 10000
|
#define ITERATION 10000
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
|
|
@ -14,44 +14,39 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "fl/armour/secure_protocol/random.h"
|
#include "fl/armour/secure_protocol/masking.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
Random::Random(size_t init_seed) { generator.seed(init_seed); }
|
|
||||||
|
|
||||||
Random::~Random() {}
|
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
|
int Masking::GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *seed, int seed_len,
|
||||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
const uint8_t *ivec, int ivec_size) {
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
|
|
||||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
#else
|
#else
|
||||||
int Random::GetRandomBytes(unsigned char *secret, int num_bytes) {
|
int Masking::GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *secret, int secret_len,
|
||||||
int retval = RAND_priv_bytes(secret, num_bytes);
|
const uint8_t *ivec, int ivec_size) {
|
||||||
return retval;
|
if ((secret_len != KEY_LENGTH_16 && secret_len != KEY_LENGTH_32) || secret == NULL) {
|
||||||
}
|
MS_LOG(ERROR) << "secret is invalid!";
|
||||||
|
return -1;
|
||||||
int Random::RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len) {
|
}
|
||||||
if (seed_len != 16 && seed_len != 32) {
|
if (noise == NULL || noise_len <= 0) {
|
||||||
MS_LOG(ERROR) << "seed length must be 16 or 32!";
|
MS_LOG(ERROR) << "noise is invalid!";
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
if (ivec == NULL || ivec_size != AES_IV_SIZE) {
|
||||||
|
MS_LOG(ERROR) << "ivec is invalid!";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
int size = noise_len * sizeof(int);
|
int size = noise_len * sizeof(int);
|
||||||
std::vector<unsigned char> data(size, 0);
|
std::vector<uint8_t> data(size, 0);
|
||||||
std::vector<unsigned char> encrypt_data(size, 0);
|
std::vector<uint8_t> encrypt_data(size, 0);
|
||||||
std::vector<unsigned char> ivec(INIT_VEC_SIZE, 0);
|
|
||||||
int encrypt_len = 0;
|
int encrypt_len = 0;
|
||||||
AESEncrypt encrypt(seed, seed_len, ivec.data(), INIT_VEC_SIZE, AES_CTR);
|
AESEncrypt encrypt(secret, secret_len, ivec, AES_IV_SIZE, AES_CTR);
|
||||||
if (encrypt.EncryptData(data.data(), size, encrypt_data.data(), &encrypt_len) != 0) {
|
if (encrypt.EncryptData(data.data(), size, encrypt_data.data(), &encrypt_len) != 0) {
|
||||||
MS_LOG(ERROR) << "call encryptData fail!";
|
MS_LOG(ERROR) << "call AES-CTR failed!";
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,27 +19,15 @@
|
||||||
|
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#ifndef _WIN32
|
|
||||||
#include <openssl/rand.h>
|
|
||||||
#endif
|
|
||||||
#include "fl/armour/secure_protocol/encrypt.h"
|
#include "fl/armour/secure_protocol/encrypt.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
|
|
||||||
#define RANDOM_LEN 8
|
class Masking {
|
||||||
|
|
||||||
class Random {
|
|
||||||
public:
|
public:
|
||||||
explicit Random(size_t init_seed);
|
static int GetMasking(std::vector<float> *noise, int noise_len, const uint8_t *secret, int secret_len,
|
||||||
~Random();
|
const uint8_t *ivec, int ivec_size);
|
||||||
// use openssl RAND_priv_bytes
|
|
||||||
static int GetRandomBytes(unsigned char *secret, int num_bytes);
|
|
||||||
|
|
||||||
static int RandomAESCTR(std::vector<float> *noise, int noise_len, const unsigned char *seed, int seed_len);
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::default_random_engine generator;
|
|
||||||
};
|
};
|
||||||
} // namespace armour
|
} // namespace armour
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
|
@ -62,6 +62,11 @@ struct RoundConfig {
|
||||||
struct CipherConfig {
|
struct CipherConfig {
|
||||||
float share_secrets_ratio = 1.0;
|
float share_secrets_ratio = 1.0;
|
||||||
uint64_t cipher_time_window = 300000;
|
uint64_t cipher_time_window = 300000;
|
||||||
|
size_t exchange_keys_threshold = 0;
|
||||||
|
size_t get_keys_threshold = 0;
|
||||||
|
size_t share_secrets_threshold = 0;
|
||||||
|
size_t get_secrets_threshold = 0;
|
||||||
|
size_t client_list_threshold = 0;
|
||||||
size_t reconstruct_secrets_threshold = 0;
|
size_t reconstruct_secrets_threshold = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -207,8 +212,11 @@ constexpr auto kCtxClientNoises = "clients_noises";
|
||||||
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";
|
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";
|
||||||
constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares";
|
constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares";
|
||||||
constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list";
|
constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list";
|
||||||
|
constexpr auto kCtxGetSecretsClientList = "get_secrets_client_list";
|
||||||
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
|
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
|
||||||
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
|
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
|
||||||
|
constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list";
|
||||||
|
constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
|
||||||
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
|
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
|
||||||
constexpr auto kCtxCipherPrimer = "cipher_primer";
|
constexpr auto kCtxCipherPrimer = "cipher_primer";
|
||||||
|
|
||||||
|
|
|
@ -41,108 +41,102 @@ void ClientListKernel::InitKernel(size_t) {
|
||||||
|
|
||||||
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
|
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
|
||||||
std::shared_ptr<server::FBBuilder> fbb) {
|
std::shared_ptr<server::FBBuilder> fbb) {
|
||||||
bool response = false;
|
|
||||||
std::vector<string> client_list;
|
std::vector<string> client_list;
|
||||||
|
std::vector<string> empty_client_list;
|
||||||
std::string fl_id = get_clients_req->fl_id()->str();
|
std::string fl_id = get_clients_req->fl_id()->str();
|
||||||
int32_t iter_client = (size_t)get_clients_req->iteration();
|
|
||||||
if (iter_num != (size_t)iter_client) {
|
if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
|
||||||
MS_LOG(ERROR) << "ClientListKernel iteration invalid. servertime is " << iter_num;
|
MS_LOG(ERROR) << "update_model_client_threshold is not set.";
|
||||||
MS_LOG(ERROR) << "ClientListKernel iteration invalid. clienttime is " << iter_client;
|
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_threshold is not set.",
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
|
empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
return false;
|
||||||
} else {
|
|
||||||
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
|
|
||||||
uint64_t update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
|
|
||||||
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
|
||||||
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
|
|
||||||
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
|
||||||
client_list.push_back(client_list_pb.fl_id(i));
|
|
||||||
}
|
|
||||||
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) { // client in client_list.
|
|
||||||
if (static_cast<uint64_t>(client_list_pb.fl_id_size()) >= update_model_client_num) {
|
|
||||||
MS_LOG(INFO) << "send clients_list succeed!";
|
|
||||||
MS_LOG(INFO) << "UpdateModel client list: ";
|
|
||||||
for (size_t i = 0; i < client_list.size(); ++i) {
|
|
||||||
MS_LOG(INFO) << " fl_id : " << client_list[i];
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "update_model_client_num: " << update_model_client_num;
|
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
response = true;
|
|
||||||
} else {
|
|
||||||
MS_LOG(INFO) << "The server is not ready. update_model_client_need_num: " << update_model_client_num;
|
|
||||||
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
|
|
||||||
/*for (size_t i = 0; i < std::min(client_list.size(), size_t(2)); ++i) {
|
|
||||||
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
|
|
||||||
}
|
|
||||||
for (size_t i = client_list.size() - size_t(1); i > std::max(client_list.size() - size_t(2), size_t(0));
|
|
||||||
--i) {
|
|
||||||
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
|
|
||||||
}*/
|
|
||||||
int count_tmp = 0;
|
|
||||||
for (size_t i = 0; i < cipher_init_->get_model_num_need_; ++i) {
|
|
||||||
size_t j = 0;
|
|
||||||
for (; j < client_list.size(); ++j) {
|
|
||||||
if (("f" + std::to_string(i)) == client_list[j]) break;
|
|
||||||
}
|
|
||||||
if (j >= client_list.size()) {
|
|
||||||
count_tmp++;
|
|
||||||
MS_LOG(INFO) << " no client_list fl_id : " << i;
|
|
||||||
if (count_tmp > 3) break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", client_list,
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (response) {
|
|
||||||
DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str());
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "update_model_client_num is zero.";
|
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_num is zero.", client_list,
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return response;
|
uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
|
||||||
|
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||||
|
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
|
||||||
|
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
||||||
|
client_list.push_back(client_list_pb.fl_id(i));
|
||||||
|
}
|
||||||
|
if (static_cast<uint64_t>(client_list.size()) < update_model_client_needed) {
|
||||||
|
MS_LOG(INFO) << "The server is not ready. update_model_client_needed: " << update_model_client_needed;
|
||||||
|
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", empty_client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (find(client_list.begin(), client_list.end(), fl_id) == client_list.end()) { // client not in update model clients
|
||||||
|
std::string reason = "fl_id: " + fl_id + " is not in the update_model_clients";
|
||||||
|
MS_LOG(INFO) << reason;
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, empty_client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool retcode_client =
|
||||||
|
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxGetUpdateModelClientList, fl_id);
|
||||||
|
if (!retcode_client) {
|
||||||
|
std::string reason = "update get update model clients failed";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, reason, empty_client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str())) {
|
||||||
|
std::string reason = "Counting for get user list request failed. Please retry later.";
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "send clients_list succeed!";
|
||||||
|
MS_LOG(INFO) << "UpdateModel client list: ";
|
||||||
|
for (size_t i = 0; i < client_list.size(); ++i) {
|
||||||
|
MS_LOG(INFO) << " fl_id : " << client_list[i];
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "update_model_client_needed: " << update_model_client_needed;
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
|
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
|
|
||||||
std::vector<string> client_list;
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
if (inputs.size() != 1) {
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
MS_LOG(ERROR) << "ClientListKernel needs 1 input,but got " << inputs.size();
|
MS_LOG(ERROR) << reason;
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel input num not match", client_list,
|
return false;
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
} else if (outputs.size() != 1) {
|
|
||||||
MS_LOG(ERROR) << "ClientListKernel needs 1 output,but got " << outputs.size();
|
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel output num not match", client_list,
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
} else {
|
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
|
||||||
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
|
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "ClientListKernel num is enough", client_list,
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
} else {
|
|
||||||
void *req_data = inputs[0]->addr;
|
|
||||||
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
|
|
||||||
|
|
||||||
if (get_clients_req == nullptr || fbb == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "GetClientList is nullptr or ClientListRsp builder is nullptr.";
|
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_RequestError,
|
|
||||||
"GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list,
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
} else {
|
|
||||||
DealClient(iter_num, get_clients_req, fbb);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||||
|
void *req_data = inputs[0]->addr;
|
||||||
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<string> client_list;
|
||||||
|
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
|
||||||
|
int32_t iter_client = (size_t)get_clients_req->iteration();
|
||||||
|
if (iter_num != (size_t)iter_client) {
|
||||||
|
MS_LOG(ERROR) << "client list iteration number is invalid: server now iteration is " << iter_num
|
||||||
|
<< ". client request iteration is " << iter_client;
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
|
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
|
||||||
|
}
|
||||||
|
|
||||||
|
DealClient(iter_num, get_clients_req, fbb);
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
clock_t end_time = clock();
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
|
@ -164,30 +158,22 @@ void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> cli
|
||||||
const int iteration) {
|
const int iteration) {
|
||||||
auto rsp_reason = client_list_resp_builder->CreateString(reason);
|
auto rsp_reason = client_list_resp_builder->CreateString(reason);
|
||||||
auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time);
|
auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time);
|
||||||
if (clients.size() > 0) {
|
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
|
||||||
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
|
for (auto client : clients) {
|
||||||
for (auto client : clients) {
|
auto client_fb = client_list_resp_builder->CreateString(client);
|
||||||
auto client_fb = client_list_resp_builder->CreateString(client);
|
clients_vector.push_back(client_fb);
|
||||||
clients_vector.push_back(client_fb);
|
MS_LOG(WARNING) << "update client list: ";
|
||||||
}
|
MS_LOG(WARNING) << client;
|
||||||
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
|
|
||||||
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
|
|
||||||
rsp_builder.add_retcode(retcode);
|
|
||||||
rsp_builder.add_reason(rsp_reason);
|
|
||||||
rsp_builder.add_clients(clients_fb);
|
|
||||||
rsp_builder.add_iteration(iteration);
|
|
||||||
rsp_builder.add_next_req_time(rsp_next_req_time);
|
|
||||||
auto rsp_exchange_keys = rsp_builder.Finish();
|
|
||||||
client_list_resp_builder->Finish(rsp_exchange_keys);
|
|
||||||
} else {
|
|
||||||
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
|
|
||||||
rsp_builder.add_retcode(retcode);
|
|
||||||
rsp_builder.add_reason(rsp_reason);
|
|
||||||
rsp_builder.add_iteration(iteration);
|
|
||||||
rsp_builder.add_next_req_time(rsp_next_req_time);
|
|
||||||
auto rsp_exchange_keys = rsp_builder.Finish();
|
|
||||||
client_list_resp_builder->Finish(rsp_exchange_keys);
|
|
||||||
}
|
}
|
||||||
|
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
|
||||||
|
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
|
||||||
|
rsp_builder.add_retcode(retcode);
|
||||||
|
rsp_builder.add_reason(rsp_reason);
|
||||||
|
rsp_builder.add_clients(clients_fb);
|
||||||
|
rsp_builder.add_iteration(iteration);
|
||||||
|
rsp_builder.add_next_req_time(rsp_next_req_time);
|
||||||
|
auto rsp_exchange_keys = rsp_builder.Finish();
|
||||||
|
client_list_resp_builder->Finish(rsp_exchange_keys);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,9 +38,36 @@ void ExchangeKeysKernel::InitKernel(size_t) {
|
||||||
cipher_key_ = &armour::CipherKeys::GetInstance();
|
cipher_key_ = &armour::CipherKeys::GetInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num) {
|
||||||
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
|
std::string reason = "Current amount for exchangeKey is enough. Please retry later.";
|
||||||
|
cipher_key_->BuildExchangeKeysRsp(
|
||||||
|
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
|
||||||
|
MS_LOG(WARNING) << reason;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
|
const schema::RequestExchangeKeys *exchange_keys_req,
|
||||||
|
const int iter_num) {
|
||||||
|
MS_ERROR_IF_NULL_W_RET_VAL(exchange_keys_req, false);
|
||||||
|
if (!DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str())) {
|
||||||
|
std::string reason = "Counting for exchange kernel request failed. Please retry later.";
|
||||||
|
cipher_key_->BuildExchangeKeysRsp(
|
||||||
|
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
MS_LOG(INFO) << "Launching ExchangeKey kernel.";
|
||||||
bool response = false;
|
bool response = false;
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
|
@ -48,46 +75,49 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
||||||
<< total_duration;
|
<< total_duration;
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
|
|
||||||
if (inputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 input,but got " << inputs.size();
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel input num not match",
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void *req_data = inputs[0]->addr;
|
||||||
|
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||||
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ReachThresholdForExchangeKeys(fbb, iter_num)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
|
||||||
|
int32_t iter_client = (size_t)exchange_keys_req->iteration();
|
||||||
|
if (iter_num != (size_t)iter_client) {
|
||||||
|
MS_LOG(ERROR) << "ExchangeKeys iteration number is invalid: server now iteration is " << iter_num
|
||||||
|
<< ". client request iteration is " << iter_client;
|
||||||
|
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
} else if (outputs.size() != 1) {
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 output,but got " << outputs.size();
|
return true;
|
||||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel output num not match",
|
}
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
|
||||||
} else {
|
if (!response) {
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
MS_LOG(WARNING) << "update exchange keys is failed.";
|
||||||
MS_LOG(ERROR) << "Current amount for ExchangeKeysKernel is enough.";
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime,
|
return true;
|
||||||
"Current amount for ExchangeKeysKernel is enough.",
|
}
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
if (!CountForExchangeKeys(fbb, exchange_keys_req, iter_num)) {
|
||||||
} else {
|
MS_LOG(ERROR) << "count for exchange keys failed.";
|
||||||
void *req_data = inputs[0]->addr;
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
const schema::RequestExchangeKeys *exchange_keys_req =
|
return true;
|
||||||
flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
|
|
||||||
int32_t iter_client = (size_t)exchange_keys_req->iteration();
|
|
||||||
if (iter_num != (size_t)iter_client) {
|
|
||||||
MS_LOG(ERROR) << "ExchangeKeysKernel iteration invalid. server now iteration is " << iter_num
|
|
||||||
<< ". client request iteration is " << iter_client;
|
|
||||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
} else {
|
|
||||||
response =
|
|
||||||
cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
|
|
||||||
if (response) {
|
|
||||||
DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
clock_t end_time = clock();
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
|
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
|
||||||
if (!response) {
|
|
||||||
MS_LOG(INFO) << "ExchangeKeysKernel response is false.";
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
#include "fl/server/kernel/round/round_kernel.h"
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
|
@ -41,6 +43,9 @@ class ExchangeKeysKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_;
|
||||||
armour::CipherKeys *cipher_key_;
|
armour::CipherKeys *cipher_key_;
|
||||||
|
bool ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num);
|
||||||
|
bool CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestExchangeKeys *exchange_keys_req,
|
||||||
|
const int iter_num);
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -37,9 +37,23 @@ void GetKeysKernel::InitKernel(size_t) {
|
||||||
cipher_key_ = &armour::CipherKeys::GetInstance();
|
cipher_key_ = &armour::CipherKeys::GetInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
|
||||||
|
const int iter_num) {
|
||||||
|
MS_ERROR_IF_NULL_W_RET_VAL(get_keys_req, false);
|
||||||
|
if (!DistributedCountService::GetInstance().Count(name_, get_keys_req->fl_id()->str())) {
|
||||||
|
std::string reason = "Counting for getkeys kernel request failed. Please retry later.";
|
||||||
|
cipher_key_->BuildGetKeysRsp(
|
||||||
|
fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||||
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), false);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
MS_LOG(INFO) << "Launching GetKeys kernel.";
|
||||||
bool response = false;
|
bool response = false;
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
|
@ -47,44 +61,48 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
||||||
<< total_duration;
|
<< total_duration;
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
|
|
||||||
if (inputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
MS_LOG(ERROR) << "GetKeysKernel needs 1 input,but got " << inputs.size();
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
|
MS_LOG(ERROR) << reason;
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
return false;
|
||||||
} else if (outputs.size() != 1) {
|
}
|
||||||
MS_LOG(ERROR) << "GetKeysKernel needs 1 output,but got " << outputs.size();
|
|
||||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
|
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
void *req_data = inputs[0]->addr;
|
||||||
} else {
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||||
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
|
MS_LOG(ERROR) << reason;
|
||||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
return false;
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
}
|
||||||
} else {
|
|
||||||
void *req_data = inputs[0]->addr;
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
|
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
|
||||||
int32_t iter_client = (size_t)get_exchange_keys_req->iteration();
|
}
|
||||||
if (iter_num != (size_t)iter_client) {
|
|
||||||
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
|
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
|
||||||
<< ". client request iteration is " << iter_client;
|
int32_t iter_client = (size_t)get_exchange_keys_req->iteration();
|
||||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
if (iter_num != (size_t)iter_client) {
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
|
||||||
} else {
|
<< ". client request iteration is " << iter_client;
|
||||||
response =
|
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||||
cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
|
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||||
if (response) {
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
DistributedCountService::GetInstance().Count(name_, get_exchange_keys_req->fl_id()->str());
|
return true;
|
||||||
}
|
}
|
||||||
}
|
response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
|
||||||
}
|
if (!response) {
|
||||||
|
MS_LOG(WARNING) << "get public keys is failed.";
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!CountForGetKeys(fbb, get_exchange_keys_req, iter_num)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
clock_t end_time = clock();
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
|
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
|
||||||
if (!response) {
|
|
||||||
MS_LOG(INFO) << "GetKeysKernel response is false.";
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
#include "fl/server/kernel/round/round_kernel.h"
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
|
@ -41,6 +43,8 @@ class GetKeysKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_;
|
||||||
armour::CipherKeys *cipher_key_;
|
armour::CipherKeys *cipher_key_;
|
||||||
|
bool CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
|
||||||
|
const int iter_num);
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -39,54 +39,72 @@ void GetSecretsKernel::InitKernel(size_t) {
|
||||||
cipher_share_ = &armour::CipherShares::GetInstance();
|
cipher_share_ = &armour::CipherShares::GetInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
|
const schema::GetShareSecrets *get_secrets_req, const int iter_num) {
|
||||||
|
MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false);
|
||||||
|
if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) {
|
||||||
|
std::string reason = "Counting for get secrets kernel request failed. Please retry later.";
|
||||||
|
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), nullptr);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
bool response = false;
|
bool response = false;
|
||||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
|
||||||
|
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
|
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
|
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
|
||||||
<< total_duration;
|
<< total_duration;
|
||||||
|
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
|
|
||||||
if (inputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
MS_LOG(ERROR) << "GetSecretsKernel needs 1 input,but got " << inputs.size();
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
|
MS_LOG(ERROR) << reason;
|
||||||
} else if (outputs.size() != 1) {
|
return false;
|
||||||
MS_LOG(ERROR) << "GetSecretsKernel needs 1 output,but got " << outputs.size();
|
|
||||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
|
|
||||||
} else {
|
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
|
||||||
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
|
|
||||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
|
|
||||||
} else {
|
|
||||||
void *req_data = inputs[0]->addr;
|
|
||||||
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
|
|
||||||
int32_t iter_client = (size_t)get_secrets_req->iteration();
|
|
||||||
if (iter_num != (size_t)iter_client) {
|
|
||||||
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
|
|
||||||
<< ". client request iteration is " << iter_client;
|
|
||||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
|
|
||||||
} else {
|
|
||||||
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
|
|
||||||
if (response) {
|
|
||||||
DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||||
|
void *req_data = inputs[0]->addr;
|
||||||
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
|
||||||
|
int32_t iter_client = (size_t)get_secrets_req->iteration();
|
||||||
|
if (iter_num != (size_t)iter_client) {
|
||||||
|
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||||
|
<< ". client request iteration is " << iter_client;
|
||||||
|
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
|
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
|
||||||
|
}
|
||||||
|
|
||||||
|
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
|
||||||
|
if (!response) {
|
||||||
|
MS_LOG(WARNING) << "get secret shares is failed.";
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!CountForGetSecrets(fbb, get_secrets_req, iter_num)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
clock_t end_time = clock();
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
|
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
|
||||||
if (!response) {
|
|
||||||
MS_LOG(INFO) << "GetSecretsKernel response is false.";
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
#include "fl/server/kernel/round/round_kernel.h"
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
|
@ -41,6 +42,8 @@ class GetSecretsKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_;
|
||||||
armour::CipherShares *cipher_share_;
|
armour::CipherShares *cipher_share_;
|
||||||
|
bool CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::GetShareSecrets *get_secrets_req,
|
||||||
|
const int iter_num);
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -52,7 +52,6 @@ void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
|
||||||
|
|
||||||
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
|
||||||
bool response = false;
|
bool response = false;
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
// MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
// MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
|
@ -61,71 +60,58 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
||||||
<< total_duration;
|
<< total_duration;
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
|
|
||||||
if (inputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
|
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
|
||||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
return false;
|
||||||
"ReconstructSecretsKernel input num not match.", iter_num,
|
}
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
|
||||||
} else if (outputs.size() != 1) {
|
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||||
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 output, but got " << outputs.size();
|
void *req_data = inputs[0]->addr;
|
||||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
|
||||||
"ReconstructSecretsKernel output num not match.", iter_num,
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||||
} else {
|
MS_LOG(ERROR) << reason;
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
return false;
|
||||||
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
|
}
|
||||||
|
|
||||||
|
// get client list from memory server.
|
||||||
|
std::vector<string> update_model_clients;
|
||||||
|
const PBMetadata update_model_clients_pb_out =
|
||||||
|
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||||
|
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
|
||||||
|
|
||||||
|
for (int i = 0; i < update_model_clients_pb.fl_id_size(); ++i) {
|
||||||
|
update_model_clients.push_back(update_model_clients_pb.fl_id(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
const schema::SendReconstructSecret *reconstruct_secret_req =
|
||||||
|
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
|
||||||
|
std::string fl_id = reconstruct_secret_req->fl_id()->str();
|
||||||
|
|
||||||
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
|
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
|
||||||
|
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
|
||||||
|
// client in get update model client list.
|
||||||
|
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SUCCEED,
|
||||||
|
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||||
|
} else {
|
||||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||||
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
|
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void *req_data = inputs[0]->addr;
|
|
||||||
const schema::SendReconstructSecret *reconstruct_secret_req =
|
|
||||||
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
|
|
||||||
|
|
||||||
// get client list from memory server.
|
|
||||||
std::vector<string> client_list;
|
|
||||||
uint64_t update_model_client_num = 0;
|
|
||||||
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
|
|
||||||
update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
|
|
||||||
} else {
|
|
||||||
MS_LOG(ERROR) << "update_model_client_num is zero.";
|
|
||||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
|
||||||
"update_model_client_num is zero.", iter_num,
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
const PBMetadata client_list_pb_out =
|
|
||||||
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
|
||||||
|
|
||||||
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
|
|
||||||
int client_list_actual_size = client_list_pb.fl_id_size();
|
|
||||||
if (client_list_actual_size < 0) {
|
|
||||||
client_list_actual_size = 0;
|
|
||||||
}
|
|
||||||
if (static_cast<uint64_t>(client_list_actual_size) < update_model_client_num) {
|
|
||||||
MS_LOG(INFO) << "ReconstructSecretsKernel : client list is not ready " << inputs.size();
|
|
||||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SucNotReady,
|
|
||||||
"ReconstructSecretsKernel : client list is not ready", iter_num,
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
|
||||||
client_list.push_back(client_list_pb.fl_id(i));
|
|
||||||
}
|
|
||||||
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
|
|
||||||
reconstruct_secret_req, fbb, client_list);
|
|
||||||
if (response) {
|
|
||||||
// MS_LOG(INFO) << "start ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
|
|
||||||
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
|
|
||||||
// MS_LOG(INFO) << "end ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
|
|
||||||
}
|
}
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
|
||||||
|
reconstruct_secret_req, fbb, update_model_clients);
|
||||||
|
if (response) {
|
||||||
|
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
|
||||||
|
}
|
||||||
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
|
MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough.";
|
||||||
|
}
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
clock_t end_time = clock();
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
|
@ -142,7 +128,6 @@ void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<ps::core::
|
||||||
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
|
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "start unmask";
|
MS_LOG(INFO) << "start unmask";
|
||||||
while (!Executor::GetInstance().Unmask()) {
|
while (!Executor::GetInstance().Unmask()) {
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
std::this_thread::sleep_for(std::chrono::milliseconds(5));
|
||||||
|
|
|
@ -36,58 +36,75 @@ void ShareSecretsKernel::InitKernel(size_t) {
|
||||||
cipher_share_ = &armour::CipherShares::GetInstance();
|
cipher_share_ = &armour::CipherShares::GetInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
|
const schema::RequestShareSecrets *share_secrets_req,
|
||||||
|
const int iter_num) {
|
||||||
|
MS_ERROR_IF_NULL_W_RET_VAL(share_secrets_req, false);
|
||||||
|
if (!DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str())) {
|
||||||
|
std::string reason = "Counting for share secret kernel request failed. Please retry later.";
|
||||||
|
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
bool response = false;
|
bool response = false;
|
||||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is "
|
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is "
|
||||||
<< total_duration;
|
<< total_duration;
|
||||||
clock_t start_time = clock();
|
clock_t start_time = clock();
|
||||||
|
|
||||||
if (inputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 input,but got " << inputs.size();
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError, "ShareSecretsKernel input num not match",
|
MS_LOG(ERROR) << reason;
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
return false;
|
||||||
} else if (outputs.size() != 1) {
|
|
||||||
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 output,but got " << outputs.size();
|
|
||||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
|
||||||
"ShareSecretsKernel output num not match",
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
} else {
|
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
|
||||||
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
|
|
||||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
|
||||||
"Current amount for ShareSecretsKernel is enough.",
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
} else {
|
|
||||||
void *req_data = inputs[0]->addr;
|
|
||||||
const schema::RequestShareSecrets *share_secrets_req =
|
|
||||||
flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
|
|
||||||
size_t iter_client = (size_t)share_secrets_req->iteration();
|
|
||||||
if (iter_num != iter_client) {
|
|
||||||
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
|
|
||||||
<< ". client request iteration is " << iter_client;
|
|
||||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
|
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
|
||||||
} else {
|
|
||||||
response =
|
|
||||||
cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
|
|
||||||
if (response) {
|
|
||||||
DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||||
|
void *req_data = inputs[0]->addr;
|
||||||
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
|
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
|
||||||
|
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||||
|
"Current amount for ShareSecretsKernel is enough.",
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
|
||||||
|
size_t iter_client = (size_t)share_secrets_req->iteration();
|
||||||
|
if (iter_num != iter_client) {
|
||||||
|
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||||
|
<< ". client request iteration is " << iter_client;
|
||||||
|
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
response = cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
|
||||||
|
if (!response) {
|
||||||
|
MS_LOG(WARNING) << "update secret shares is failed.";
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!CountForShareSecrets(fbb, share_secrets_req, iter_num)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
clock_t end_time = clock();
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||||
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
|
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
|
||||||
if (!response) {
|
|
||||||
MS_LOG(INFO) << "share_secrets_kernel response is false.";
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
#include "fl/server/kernel/round/round_kernel.h"
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
|
@ -41,6 +43,8 @@ class ShareSecretsKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_;
|
||||||
armour::CipherShares *cipher_share_;
|
armour::CipherShares *cipher_share_;
|
||||||
|
bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req,
|
||||||
|
const int iter_num);
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -157,14 +157,31 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
|
||||||
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
|
PBMetadata device_metas = DistributedMetadataStore::GetInstance().GetMetadata(kCtxDeviceMetas);
|
||||||
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
|
FLIdToDeviceMeta fl_id_to_meta = device_metas.device_metas();
|
||||||
std::string update_model_fl_id = update_model_req->fl_id()->str();
|
std::string update_model_fl_id = update_model_req->fl_id()->str();
|
||||||
MS_LOG(INFO) << "Update model for fl id " << update_model_fl_id;
|
MS_LOG(INFO) << "UpdateModel for fl id " << update_model_fl_id;
|
||||||
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
|
if (ps::PSContext::instance()->encrypt_type() != ps::kPWEncryptType) {
|
||||||
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
|
if (fl_id_to_meta.fl_id_to_meta().count(update_model_fl_id) == 0) {
|
||||||
BuildUpdateModelRsp(
|
std::string reason = "devices_meta for " + update_model_fl_id + " is not set. Please retry later.";
|
||||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
BuildUpdateModelRsp(
|
||||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
MS_LOG(ERROR) << reason;
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||||
return ResultCode::kSuccessAndReturn;
|
MS_LOG(ERROR) << reason;
|
||||||
|
return ResultCode::kSuccessAndReturn;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::vector<std::string> get_secrets_clients;
|
||||||
|
#ifdef ENABLE_ARMOUR
|
||||||
|
mindspore::armour::CipherMetaStorage cipher_meta_storage;
|
||||||
|
cipher_meta_storage.GetClientListFromServer(fl::server::kCtxGetSecretsClientList, &get_secrets_clients);
|
||||||
|
#endif
|
||||||
|
if (find(get_secrets_clients.begin(), get_secrets_clients.end(), update_model_fl_id) ==
|
||||||
|
get_secrets_clients.end()) { // the client not in get_secrets_clients
|
||||||
|
std::string reason = "fl_id: " + update_model_fl_id + " is not in get_secrets_clients. Please retry later.";
|
||||||
|
BuildUpdateModelRsp(
|
||||||
|
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return ResultCode::kSuccessAndReturn;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
|
size_t data_size = fl_id_to_meta.fl_id_to_meta().at(update_model_fl_id).data_size();
|
||||||
|
|
|
@ -25,6 +25,9 @@
|
||||||
#include "fl/server/kernel/round/round_kernel.h"
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
|
#ifdef ENABLE_ARMOUR
|
||||||
|
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
|
|
|
@ -213,53 +213,24 @@ void Server::InitIteration() {
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||||
if (encrypt_type == ps::kPWEncryptType) {
|
if (encrypt_type == ps::kPWEncryptType) {
|
||||||
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
|
cipher_exchange_keys_cnt_ = cipher_config_.exchange_keys_threshold;
|
||||||
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
|
cipher_get_keys_cnt_ = cipher_config_.get_keys_threshold;
|
||||||
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
|
cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
|
||||||
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
|
cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
|
||||||
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
|
cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
|
||||||
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold;
|
cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
|
||||||
|
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
|
||||||
cipher_time_window_ = cipher_config_.cipher_time_window;
|
cipher_time_window_ = cipher_config_.cipher_time_window;
|
||||||
|
|
||||||
MS_LOG(INFO) << "Initializing cipher:";
|
MS_LOG(INFO) << "Initializing cipher:";
|
||||||
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
|
MS_LOG(INFO) << " cipher_exchange_keys_cnt_: " << cipher_exchange_keys_cnt_
|
||||||
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
|
<< " cipher_get_keys_cnt_: " << cipher_get_keys_cnt_
|
||||||
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
|
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
|
||||||
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
|
MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
|
||||||
|
<< " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
|
||||||
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
|
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
|
||||||
<< " cipher_time_window_: " << cipher_time_window_
|
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
|
||||||
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
|
<< " cipher_time_window_: " << cipher_time_window_;
|
||||||
|
|
||||||
std::shared_ptr<Round> exchange_keys_round =
|
|
||||||
std::make_shared<Round>("exchangeKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
|
|
||||||
MS_EXCEPTION_IF_NULL(exchange_keys_round);
|
|
||||||
iteration_->AddRound(exchange_keys_round);
|
|
||||||
|
|
||||||
std::shared_ptr<Round> get_keys_round =
|
|
||||||
std::make_shared<Round>("getKeys", true, cipher_time_window_, true, cipher_exchange_secrets_cnt_);
|
|
||||||
MS_EXCEPTION_IF_NULL(get_keys_round);
|
|
||||||
iteration_->AddRound(get_keys_round);
|
|
||||||
|
|
||||||
std::shared_ptr<Round> share_secrets_round =
|
|
||||||
std::make_shared<Round>("shareSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
|
|
||||||
MS_EXCEPTION_IF_NULL(share_secrets_round);
|
|
||||||
iteration_->AddRound(share_secrets_round);
|
|
||||||
|
|
||||||
std::shared_ptr<Round> get_secrets_round =
|
|
||||||
std::make_shared<Round>("getSecrets", true, cipher_time_window_, true, cipher_share_secrets_cnt_);
|
|
||||||
MS_EXCEPTION_IF_NULL(get_secrets_round);
|
|
||||||
iteration_->AddRound(get_secrets_round);
|
|
||||||
|
|
||||||
std::shared_ptr<Round> get_clientlist_round =
|
|
||||||
std::make_shared<Round>("getClientList", true, cipher_time_window_, true, cipher_get_clientlist_cnt_);
|
|
||||||
MS_EXCEPTION_IF_NULL(get_clientlist_round);
|
|
||||||
iteration_->AddRound(get_clientlist_round);
|
|
||||||
|
|
||||||
std::shared_ptr<Round> reconstruct_secrets_round = std::make_shared<Round>(
|
|
||||||
"reconstructSecrets", true, cipher_time_window_, true, cipher_reconstruct_secrets_up_cnt_);
|
|
||||||
MS_EXCEPTION_IF_NULL(reconstruct_secrets_round);
|
|
||||||
iteration_->AddRound(reconstruct_secrets_round);
|
|
||||||
MS_LOG(INFO) << "Cipher rounds has been added.";
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -314,8 +285,8 @@ void Server::InitCipher() {
|
||||||
param.dp_eps = dp_eps;
|
param.dp_eps = dp_eps;
|
||||||
param.dp_norm_clip = dp_norm_clip;
|
param.dp_norm_clip = dp_norm_clip;
|
||||||
param.encrypt_type = encrypt_type;
|
param.encrypt_type = encrypt_type;
|
||||||
cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_,
|
cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
|
||||||
cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
|
cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
|
||||||
cipher_reconstruct_secrets_up_cnt_);
|
cipher_reconstruct_secrets_up_cnt_);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,7 +80,7 @@ class Server {
|
||||||
worker_num_(0),
|
worker_num_(0),
|
||||||
fl_server_port_(0),
|
fl_server_port_(0),
|
||||||
cipher_initial_client_cnt_(0),
|
cipher_initial_client_cnt_(0),
|
||||||
cipher_exchange_secrets_cnt_(0),
|
cipher_exchange_keys_cnt_(0),
|
||||||
cipher_share_secrets_cnt_(0),
|
cipher_share_secrets_cnt_(0),
|
||||||
cipher_get_clientlist_cnt_(0),
|
cipher_get_clientlist_cnt_(0),
|
||||||
cipher_reconstruct_secrets_up_cnt_(0),
|
cipher_reconstruct_secrets_up_cnt_(0),
|
||||||
|
@ -197,8 +197,10 @@ class Server {
|
||||||
uint32_t worker_num_;
|
uint32_t worker_num_;
|
||||||
uint16_t fl_server_port_;
|
uint16_t fl_server_port_;
|
||||||
size_t cipher_initial_client_cnt_;
|
size_t cipher_initial_client_cnt_;
|
||||||
size_t cipher_exchange_secrets_cnt_;
|
size_t cipher_exchange_keys_cnt_;
|
||||||
|
size_t cipher_get_keys_cnt_;
|
||||||
size_t cipher_share_secrets_cnt_;
|
size_t cipher_share_secrets_cnt_;
|
||||||
|
size_t cipher_get_secrets_cnt_;
|
||||||
size_t cipher_get_clientlist_cnt_;
|
size_t cipher_get_clientlist_cnt_;
|
||||||
size_t cipher_reconstruct_secrets_up_cnt_;
|
size_t cipher_reconstruct_secrets_up_cnt_;
|
||||||
size_t cipher_reconstruct_secrets_down_cnt_;
|
size_t cipher_reconstruct_secrets_down_cnt_;
|
||||||
|
|
|
@ -856,9 +856,33 @@ bool StartServerAction(const ResourcePtr &res) {
|
||||||
|
|
||||||
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
|
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
|
||||||
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
|
uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
|
||||||
size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold();
|
size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold() + 1;
|
||||||
|
|
||||||
fl::server::CipherConfig cipher_config = {share_secrets_ratio, cipher_time_window, reconstruct_secrets_threshold};
|
size_t exchange_keys_threshold =
|
||||||
|
std::max(static_cast<size_t>(std::ceil(start_fl_job_threshold * share_secrets_ratio)), update_model_threshold);
|
||||||
|
size_t get_keys_threshold =
|
||||||
|
std::max(static_cast<size_t>(std::ceil(exchange_keys_threshold * share_secrets_ratio)), update_model_threshold);
|
||||||
|
size_t share_secrets_threshold =
|
||||||
|
std::max(static_cast<size_t>(std::ceil(get_keys_threshold * share_secrets_ratio)), update_model_threshold);
|
||||||
|
size_t get_secrets_threshold =
|
||||||
|
std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold);
|
||||||
|
size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)),
|
||||||
|
reconstruct_secrets_threshold);
|
||||||
|
#ifdef ENABLE_ARMOUR
|
||||||
|
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||||
|
if (encrypt_type == ps::kPWEncryptType) {
|
||||||
|
MS_LOG(INFO) << "Add secure aggregation rounds.";
|
||||||
|
rounds_config.push_back({"exchangeKeys", true, cipher_time_window, true, exchange_keys_threshold});
|
||||||
|
rounds_config.push_back({"getKeys", true, cipher_time_window, true, get_keys_threshold});
|
||||||
|
rounds_config.push_back({"shareSecrets", true, cipher_time_window, true, share_secrets_threshold});
|
||||||
|
rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold});
|
||||||
|
rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
|
||||||
|
rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
fl::server::CipherConfig cipher_config = {
|
||||||
|
share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,
|
||||||
|
share_secrets_threshold, get_secrets_threshold, client_list_threshold, reconstruct_secrets_threshold};
|
||||||
|
|
||||||
size_t executor_threshold = 0;
|
size_t executor_threshold = 0;
|
||||||
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
||||||
|
|
|
@ -126,6 +126,13 @@ message SharesPb {
|
||||||
|
|
||||||
message KeysPb {
|
message KeysPb {
|
||||||
repeated bytes key = 1;
|
repeated bytes key = 1;
|
||||||
|
string timestamp = 2;
|
||||||
|
int32 iter_num = 3;
|
||||||
|
bytes ind_iv = 4;
|
||||||
|
bytes pw_iv = 5;
|
||||||
|
bytes pw_salt = 6;
|
||||||
|
bytes signature = 7;
|
||||||
|
repeated string certificate_chain = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Prime {
|
message Prime {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,12 +16,18 @@
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.I_VEC_LEN;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.SALT_SIZE;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE;
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
|
|
||||||
import com.mindspore.flclient.cipher.AESEncrypt;
|
import com.mindspore.flclient.cipher.AESEncrypt;
|
||||||
import com.mindspore.flclient.cipher.BaseUtil;
|
import com.mindspore.flclient.cipher.BaseUtil;
|
||||||
import com.mindspore.flclient.cipher.ClientListReq;
|
import com.mindspore.flclient.cipher.ClientListReq;
|
||||||
import com.mindspore.flclient.cipher.KEYAgreement;
|
import com.mindspore.flclient.cipher.KEYAgreement;
|
||||||
import com.mindspore.flclient.cipher.Random;
|
import com.mindspore.flclient.cipher.Masking;
|
||||||
import com.mindspore.flclient.cipher.ReconstructSecretReq;
|
import com.mindspore.flclient.cipher.ReconstructSecretReq;
|
||||||
import com.mindspore.flclient.cipher.ShareSecrets;
|
import com.mindspore.flclient.cipher.ShareSecrets;
|
||||||
import com.mindspore.flclient.cipher.struct.ClientPublicKey;
|
import com.mindspore.flclient.cipher.struct.ClientPublicKey;
|
||||||
|
@ -29,6 +35,7 @@ import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
||||||
import com.mindspore.flclient.cipher.struct.EncryptShare;
|
import com.mindspore.flclient.cipher.struct.EncryptShare;
|
||||||
import com.mindspore.flclient.cipher.struct.NewArray;
|
import com.mindspore.flclient.cipher.struct.NewArray;
|
||||||
import com.mindspore.flclient.cipher.struct.ShareSecret;
|
import com.mindspore.flclient.cipher.struct.ShareSecret;
|
||||||
|
|
||||||
import mindspore.schema.ClientShare;
|
import mindspore.schema.ClientShare;
|
||||||
import mindspore.schema.GetExchangeKeys;
|
import mindspore.schema.GetExchangeKeys;
|
||||||
import mindspore.schema.GetShareSecrets;
|
import mindspore.schema.GetShareSecrets;
|
||||||
|
@ -40,22 +47,22 @@ import mindspore.schema.ResponseShareSecrets;
|
||||||
import mindspore.schema.ReturnExchangeKeys;
|
import mindspore.schema.ReturnExchangeKeys;
|
||||||
import mindspore.schema.ReturnShareSecrets;
|
import mindspore.schema.ReturnShareSecrets;
|
||||||
|
|
||||||
import java.io.UnsupportedEncodingException;
|
import java.io.IOException;
|
||||||
import java.math.BigInteger;
|
import java.math.BigInteger;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.security.NoSuchAlgorithmException;
|
import java.security.SecureRandom;
|
||||||
import java.security.spec.InvalidKeySpecException;
|
|
||||||
import java.time.LocalDateTime;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Date;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
/**
|
||||||
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
|
* A class used for secure aggregation
|
||||||
import static com.mindspore.flclient.LocalFLParameter.SEED_SIZE;
|
*
|
||||||
|
* @since 2021-8-27
|
||||||
|
*/
|
||||||
public class CipherClient {
|
public class CipherClient {
|
||||||
private static final Logger LOGGER = Logger.getLogger(CipherClient.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(CipherClient.class.toString());
|
||||||
private FLCommunication flCommunication;
|
private FLCommunication flCommunication;
|
||||||
|
@ -63,129 +70,217 @@ public class CipherClient {
|
||||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||||
private final int iteration;
|
private final int iteration;
|
||||||
private int featureSize;
|
private int featureSize;
|
||||||
private int t;
|
private int minShareNum;
|
||||||
private List<byte[]> cKey = new ArrayList<>();
|
private List<byte[]> cKey = new ArrayList<>();
|
||||||
private List<byte[]> sKey = new ArrayList<>();
|
private List<byte[]> sKey = new ArrayList<>();
|
||||||
private byte[] bu;
|
private byte[] bu;
|
||||||
|
private byte[] individualIv = new byte[I_VEC_LEN];
|
||||||
|
private byte[] pwIVec = new byte[I_VEC_LEN];
|
||||||
|
private byte[] pwSalt = new byte[SALT_SIZE];
|
||||||
private String nextRequestTime;
|
private String nextRequestTime;
|
||||||
private Map<String, ClientPublicKey> clientPublicKeyList = new HashMap<String, ClientPublicKey>();
|
private Map<String, ClientPublicKey> clientPublicKeyList = new HashMap<String, ClientPublicKey>();
|
||||||
private Map<String, byte[]> sUVKeys = new HashMap<String, byte[]>();
|
private Map<String, byte[]> sUVKeys = new HashMap<String, byte[]>();
|
||||||
private Map<String, byte[]> cUVKeys = new HashMap<String, byte[]>();
|
private Map<String, byte[]> cUVKeys = new HashMap<String, byte[]>();
|
||||||
private List<EncryptShare> clientShareList = new ArrayList<>();
|
private List<EncryptShare> clientShareList = new ArrayList<>();
|
||||||
private List<EncryptShare> returnShareList = new ArrayList<>();
|
private List<EncryptShare> returnShareList = new ArrayList<>();
|
||||||
private float[] featureMask;
|
|
||||||
private List<String> u1ClientList = new ArrayList<>();
|
private List<String> u1ClientList = new ArrayList<>();
|
||||||
private List<String> u2UClientList = new ArrayList<>();
|
private List<String> u2UClientList = new ArrayList<>();
|
||||||
private List<String> u3ClientList = new ArrayList<>();
|
private List<String> u3ClientList = new ArrayList<>();
|
||||||
private List<DecryptShareSecrets> decryptShareSecretsList = new ArrayList<>();
|
private List<DecryptShareSecrets> decryptShareSecretsList = new ArrayList<>();
|
||||||
private byte[] prime;
|
private byte[] prime;
|
||||||
private KEYAgreement keyAgreement = new KEYAgreement();
|
private KEYAgreement keyAgreement = new KEYAgreement();
|
||||||
private Random random = new Random();
|
private Masking masking = new Masking();
|
||||||
private ClientListReq clientListReq = new ClientListReq();
|
private ClientListReq clientListReq = new ClientListReq();
|
||||||
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
|
private ReconstructSecretReq reconstructSecretReq = new ReconstructSecretReq();
|
||||||
private int retCode;
|
private int retCode;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct function of cipherClient
|
||||||
|
*
|
||||||
|
* @param iter iteration number
|
||||||
|
* @param minSecretNum minimum secret shares number used for reconstruct secret
|
||||||
|
* @param prime prime value
|
||||||
|
* @param featureSize featureSize of network
|
||||||
|
*/
|
||||||
public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) {
|
public CipherClient(int iter, int minSecretNum, byte[] prime, int featureSize) {
|
||||||
flCommunication = FLCommunication.getInstance();
|
flCommunication = FLCommunication.getInstance();
|
||||||
this.iteration = iter;
|
this.iteration = iter;
|
||||||
this.featureSize = featureSize;
|
this.featureSize = featureSize;
|
||||||
this.t = minSecretNum;
|
this.minShareNum = minSecretNum;
|
||||||
this.prime = prime;
|
this.prime = prime;
|
||||||
this.featureMask = new float[this.featureSize];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set next request time
|
||||||
|
*
|
||||||
|
* @param nextRequestTime next request timestamp
|
||||||
|
*/
|
||||||
public void setNextRequestTime(String nextRequestTime) {
|
public void setNextRequestTime(String nextRequestTime) {
|
||||||
this.nextRequestTime = nextRequestTime;
|
this.nextRequestTime = nextRequestTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setBU(byte[] bu) {
|
/**
|
||||||
this.bu = bu;
|
* Set client share list
|
||||||
}
|
*
|
||||||
|
* @param clientShareList client share list
|
||||||
public void setClientShareList(List<EncryptShare> clientShareList) {
|
*/
|
||||||
|
private void setClientShareList(List<EncryptShare> clientShareList) {
|
||||||
this.clientShareList.clear();
|
this.clientShareList.clear();
|
||||||
this.clientShareList = clientShareList;
|
this.clientShareList = clientShareList;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get next request time
|
||||||
|
*
|
||||||
|
* @return next request time
|
||||||
|
*/
|
||||||
public String getNextRequestTime() {
|
public String getNextRequestTime() {
|
||||||
return nextRequestTime;
|
return nextRequestTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get retCode
|
||||||
|
*
|
||||||
|
* @return retCode
|
||||||
|
*/
|
||||||
public int getRetCode() {
|
public int getRetCode() {
|
||||||
return retCode;
|
return retCode;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void genDHKeyPairs() {
|
private FLClientStatus genDHKeyPairs() {
|
||||||
byte[] csk = keyAgreement.generatePrivateKey();
|
byte[] csk = keyAgreement.generatePrivateKey();
|
||||||
byte[] cpk = keyAgreement.generatePublicKey(csk);
|
byte[] cpk = keyAgreement.generatePublicKey(csk);
|
||||||
|
if (cpk == null || cpk.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[genDHKeyPairs] the return byte[] <cpk> is null, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
byte[] ssk = keyAgreement.generatePrivateKey();
|
byte[] ssk = keyAgreement.generatePrivateKey();
|
||||||
byte[] spk = keyAgreement.generatePublicKey(ssk);
|
byte[] spk = keyAgreement.generatePublicKey(ssk);
|
||||||
|
if (spk == null || spk.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[genDHKeyPairs] the return byte[] <spk> is null, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
this.cKey.clear();
|
||||||
|
this.sKey.clear();
|
||||||
this.cKey.add(cpk);
|
this.cKey.add(cpk);
|
||||||
this.cKey.add(csk);
|
this.cKey.add(csk);
|
||||||
this.sKey.add(spk);
|
this.sKey.add(spk);
|
||||||
this.sKey.add(ssk);
|
this.sKey.add(ssk);
|
||||||
|
return FLClientStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void genIndividualSecret() {
|
private FLClientStatus genIndividualSecret() {
|
||||||
byte[] key = new byte[SEED_SIZE];
|
byte[] key = new byte[SEED_SIZE];
|
||||||
random.getRandomBytes(key);
|
int tag = masking.getRandomBytes(key);
|
||||||
setBU(key);
|
if (tag == -1) {
|
||||||
|
LOGGER.severe(Common.addTag("[genIndividualSecret] the return value is -1, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
this.bu = key;
|
||||||
|
return FLClientStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<ShareSecret> genSecretShares(byte[] secret) throws UnsupportedEncodingException {
|
private List<ShareSecret> genSecretShares(byte[] secret) {
|
||||||
List<ShareSecret> shareSecretList = new ArrayList<>();
|
if (secret == null || secret.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[genSecretShares] the input argument <secret> is null"));
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
int size = u1ClientList.size();
|
int size = u1ClientList.size();
|
||||||
ShareSecrets shamir = new ShareSecrets(t, size - 1);
|
if (size <= 1) {
|
||||||
ShareSecrets.SecretShare[] shares = shamir.split(secret, prime);
|
LOGGER.severe(Common.addTag("[genSecretShares] the size of u1ClientList is not valid: <= 1, it should be " +
|
||||||
int j = 0;
|
"> 1"));
|
||||||
for (int i = 0; i < size; i++) {
|
return new ArrayList<>();
|
||||||
String vFlID = u1ClientList.get(i);
|
}
|
||||||
|
ShareSecrets shamir = new ShareSecrets(minShareNum, size - 1);
|
||||||
|
ShareSecrets.SecretShares[] shares = shamir.split(secret, prime);
|
||||||
|
if (shares == null || shares.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[genSecretShares] the return ShareSecrets.SecretShare[] is null, please " +
|
||||||
|
"check!"));
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
int shareIndex = 0;
|
||||||
|
List<ShareSecret> shareSecretList = new ArrayList<>();
|
||||||
|
for (String vFlID : u1ClientList) {
|
||||||
if (localFLParameter.getFlID().equals(vFlID)) {
|
if (localFLParameter.getFlID().equals(vFlID)) {
|
||||||
continue;
|
continue;
|
||||||
} else {
|
|
||||||
ShareSecret shareSecret = new ShareSecret();
|
|
||||||
NewArray<byte[]> array = new NewArray<>();
|
|
||||||
int index = shares[j].getNum();
|
|
||||||
BigInteger intShare = shares[j].getShare();
|
|
||||||
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
|
|
||||||
array.setSize(share.length);
|
|
||||||
array.setArray(share);
|
|
||||||
shareSecret.setFlID(vFlID);
|
|
||||||
shareSecret.setShare(array);
|
|
||||||
shareSecret.setIndex(index);
|
|
||||||
shareSecretList.add(shareSecret);
|
|
||||||
j += 1;
|
|
||||||
}
|
}
|
||||||
|
if (shareIndex >= shares.length) {
|
||||||
|
LOGGER.severe(Common.addTag("[genSecretShares] the shareIndex is out of range in array <shares>, " +
|
||||||
|
"please check!"));
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
int index = shares[shareIndex].getNumber();
|
||||||
|
BigInteger intShare = shares[shareIndex].getShares();
|
||||||
|
byte[] share = BaseUtil.bigInteger2byteArray(intShare);
|
||||||
|
NewArray<byte[]> array = new NewArray<>();
|
||||||
|
array.setSize(share.length);
|
||||||
|
array.setArray(share);
|
||||||
|
ShareSecret shareSecret = new ShareSecret();
|
||||||
|
shareSecret.setFlID(vFlID);
|
||||||
|
shareSecret.setShare(array);
|
||||||
|
shareSecret.setIndex(index);
|
||||||
|
shareSecretList.add(shareSecret);
|
||||||
|
shareIndex += 1;
|
||||||
}
|
}
|
||||||
return shareSecretList;
|
return shareSecretList;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void genEncryptExchangedKeys() throws InvalidKeySpecException, NoSuchAlgorithmException {
|
private FLClientStatus genEncryptExchangedKeys() {
|
||||||
cUVKeys.clear();
|
cUVKeys.clear();
|
||||||
for (String key : clientPublicKeyList.keySet()) {
|
for (String key : clientPublicKeyList.keySet()) {
|
||||||
ClientPublicKey curPublicKey = clientPublicKeyList.get(key);
|
ClientPublicKey curPublicKey = clientPublicKeyList.get(key);
|
||||||
String vFlID = curPublicKey.getFlID();
|
String vFlID = curPublicKey.getFlID();
|
||||||
if (localFLParameter.getFlID().equals(vFlID)) {
|
if (localFLParameter.getFlID().equals(vFlID)) {
|
||||||
continue;
|
continue;
|
||||||
} else {
|
|
||||||
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
|
|
||||||
byte[] salt = new byte[0];
|
|
||||||
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
|
||||||
cUVKeys.put(vFlID, secret);
|
|
||||||
}
|
}
|
||||||
|
if (cKey.size() < 2) {
|
||||||
|
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the size of cKey is not valid: < 2, it should " +
|
||||||
|
"be >= 2, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
byte[] secret1 = keyAgreement.keyAgreement(cKey.get(1), curPublicKey.getCPK().getArray());
|
||||||
|
if (secret1 == null || secret1.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the returned secret1 is null, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
byte[] salt = new byte[0];
|
||||||
|
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
||||||
|
if (secret == null || secret.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[genEncryptExchangedKeys] the returned secret is null, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
cUVKeys.put(vFlID, secret);
|
||||||
}
|
}
|
||||||
|
return FLClientStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void encryptShares() throws Exception {
|
private FLClientStatus encryptShares() {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for RequestShareSecrets **************"));
|
LOGGER.info(Common.addTag("[PairWiseMask] ************** generate encrypt share secrets for " +
|
||||||
List<EncryptShare> encryptShareList = new ArrayList<>();
|
"RequestShareSecrets **************"));
|
||||||
// connect sSkUv, bUV, sIndex, indexB and then Encrypt them
|
// connect sSkUv, bUV, sIndex, indexB and then Encrypt them
|
||||||
|
if (sKey.size() < 2) {
|
||||||
|
LOGGER.severe(Common.addTag("[encryptShares] the size of sKey is not valid: < 2, it should be >= 2, " +
|
||||||
|
"please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
List<ShareSecret> sSkUv = genSecretShares(sKey.get(1));
|
List<ShareSecret> sSkUv = genSecretShares(sKey.get(1));
|
||||||
|
if (sSkUv.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[encryptShares] the returned List<ShareSecret> sSkUv is empty, please " +
|
||||||
|
"check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
List<ShareSecret> bUV = genSecretShares(bu);
|
List<ShareSecret> bUV = genSecretShares(bu);
|
||||||
|
if (sSkUv.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[encryptShares] the returned List<ShareSecret> bUV is empty, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
if (sSkUv.size() != bUV.size()) {
|
||||||
|
LOGGER.severe(Common.addTag("[encryptShares] the sSkUv.size() should be equal to bUV.size(), please " +
|
||||||
|
"check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
List<EncryptShare> encryptShareList = new ArrayList<>();
|
||||||
for (int i = 0; i < bUV.size(); i++) {
|
for (int i = 0; i < bUV.size(); i++) {
|
||||||
EncryptShare encryptShare = new EncryptShare();
|
|
||||||
NewArray<byte[]> array = new NewArray<>();
|
|
||||||
String vFlID = bUV.get(i).getFlID();
|
|
||||||
byte[] sShare = sSkUv.get(i).getShare().getArray();
|
byte[] sShare = sSkUv.get(i).getShare().getArray();
|
||||||
byte[] bShare = bUV.get(i).getShare().getArray();
|
byte[] bShare = bUV.get(i).getShare().getArray();
|
||||||
byte[] sIndex = BaseUtil.integer2byteArray(sSkUv.get(i).getIndex());
|
byte[] sIndex = BaseUtil.integer2byteArray(sSkUv.get(i).getIndex());
|
||||||
|
@ -200,53 +295,101 @@ public class CipherClient {
|
||||||
System.arraycopy(sShare, 0, allSecret, 4 + sIndex.length + bIndex.length, sShare.length);
|
System.arraycopy(sShare, 0, allSecret, 4 + sIndex.length + bIndex.length, sShare.length);
|
||||||
System.arraycopy(bShare, 0, allSecret, 4 + sIndex.length + bIndex.length + sShare.length, bShare.length);
|
System.arraycopy(bShare, 0, allSecret, 4 + sIndex.length + bIndex.length + sShare.length, bShare.length);
|
||||||
// encrypt:
|
// encrypt:
|
||||||
byte[] iVecIn = new byte[IVEC_LEN];
|
String vFlID = bUV.get(i).getFlID();
|
||||||
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), iVecIn, "CBC");
|
if (!cUVKeys.containsKey(vFlID)) {
|
||||||
|
LOGGER.severe(Common.addTag("[encryptShares] the key " + vFlID + " is not in map cUVKeys, please " +
|
||||||
|
"check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
AESEncrypt aesEncrypt = new AESEncrypt(cUVKeys.get(vFlID), "CBC");
|
||||||
byte[] encryptData = aesEncrypt.encrypt(cUVKeys.get(vFlID), allSecret);
|
byte[] encryptData = aesEncrypt.encrypt(cUVKeys.get(vFlID), allSecret);
|
||||||
|
if (encryptData == null || encryptData.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[encryptShares] the return byte[] is null, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
NewArray<byte[]> array = new NewArray<>();
|
||||||
array.setSize(encryptData.length);
|
array.setSize(encryptData.length);
|
||||||
array.setArray(encryptData);
|
array.setArray(encryptData);
|
||||||
|
EncryptShare encryptShare = new EncryptShare();
|
||||||
encryptShare.setFlID(vFlID);
|
encryptShare.setFlID(vFlID);
|
||||||
encryptShare.setShare(array);
|
encryptShare.setShare(array);
|
||||||
encryptShareList.add(encryptShare);
|
encryptShareList.add(encryptShare);
|
||||||
}
|
}
|
||||||
setClientShareList(encryptShareList);
|
setClientShareList(encryptShareList);
|
||||||
|
return FLClientStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
public float[] doubleMaskingWeight() throws Exception {
|
/**
|
||||||
int size = u2UClientList.size();
|
* get masked weight of secure aggregation
|
||||||
|
*
|
||||||
|
* @return masked weight
|
||||||
|
*/
|
||||||
|
public float[] doubleMaskingWeight() {
|
||||||
List<Float> noiseBu = new ArrayList<>();
|
List<Float> noiseBu = new ArrayList<>();
|
||||||
random.randomAESCTR(noiseBu, featureSize, bu);
|
int tag = masking.getMasking(noiseBu, featureSize, bu, individualIv);
|
||||||
|
if (tag == -1) {
|
||||||
|
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the return value is -1, please check!"));
|
||||||
|
return new float[0];
|
||||||
|
}
|
||||||
float[] mask = new float[featureSize];
|
float[] mask = new float[featureSize];
|
||||||
for (int i = 0; i < size; i++) {
|
for (String vFlID : u2UClientList) {
|
||||||
String vFlID = u2UClientList.get(i);
|
if (!clientPublicKeyList.containsKey(vFlID)) {
|
||||||
|
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the key " + vFlID + " is not in map " +
|
||||||
|
"clientPublicKeyList, please check!"));
|
||||||
|
return new float[0];
|
||||||
|
}
|
||||||
ClientPublicKey curPublicKey = clientPublicKeyList.get(vFlID);
|
ClientPublicKey curPublicKey = clientPublicKeyList.get(vFlID);
|
||||||
if (localFLParameter.getFlID().equals(vFlID)) {
|
if (localFLParameter.getFlID().equals(vFlID)) {
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
|
byte[] salt;
|
||||||
|
byte[] iVec;
|
||||||
|
if (vFlID.compareTo(localFLParameter.getFlID()) < 0) {
|
||||||
|
salt = curPublicKey.getPwSalt().getArray();
|
||||||
|
iVec = curPublicKey.getPwIv().getArray();
|
||||||
} else {
|
} else {
|
||||||
byte[] salt = new byte[0];
|
salt = this.pwSalt;
|
||||||
byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray());
|
iVec = this.pwIVec;
|
||||||
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
}
|
||||||
sUVKeys.put(vFlID, secret);
|
if (sKey.size() < 2) {
|
||||||
List<Float> noiseSuv = new ArrayList<>();
|
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the size of sKey is not valid: < 2, it should be " +
|
||||||
random.randomAESCTR(noiseSuv, featureSize, secret);
|
">= 2, please check!"));
|
||||||
int sign;
|
return new float[0];
|
||||||
if (localFLParameter.getFlID().compareTo(vFlID) > 0) {
|
}
|
||||||
sign = 1;
|
byte[] secret1 = keyAgreement.keyAgreement(sKey.get(1), curPublicKey.getSPK().getArray());
|
||||||
} else {
|
if (secret1 == null || secret1.length == 0) {
|
||||||
sign = -1;
|
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the returned secret1 is null, please check!"));
|
||||||
}
|
return new float[0];
|
||||||
for (int j = 0; j < noiseSuv.size(); j++) {
|
}
|
||||||
mask[j] = mask[j] + sign * noiseSuv.get(j);
|
byte[] secret = keyAgreement.getEncryptedPassword(secret1, salt);
|
||||||
}
|
if (secret == null || secret.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the returned secret is null, please check!"));
|
||||||
|
return new float[0];
|
||||||
|
}
|
||||||
|
sUVKeys.put(vFlID, secret);
|
||||||
|
List<Float> noiseSuv = new ArrayList<>();
|
||||||
|
tag = masking.getMasking(noiseSuv, featureSize, secret, iVec);
|
||||||
|
if (tag == -1) {
|
||||||
|
LOGGER.severe(Common.addTag("[doubleMaskingWeight] the return value is -1, please check!"));
|
||||||
|
return new float[0];
|
||||||
|
}
|
||||||
|
int sign;
|
||||||
|
if (localFLParameter.getFlID().compareTo(vFlID) > 0) {
|
||||||
|
sign = 1;
|
||||||
|
} else {
|
||||||
|
sign = -1;
|
||||||
|
}
|
||||||
|
for (int maskIndex = 0; maskIndex < noiseSuv.size(); maskIndex++) {
|
||||||
|
mask[maskIndex] = mask[maskIndex] + sign * noiseSuv.get(maskIndex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int j = 0; j < noiseBu.size(); j++) {
|
for (int maskIndex = 0; maskIndex < noiseBu.size(); maskIndex++) {
|
||||||
mask[j] = mask[j] + noiseBu.get(j);
|
mask[maskIndex] = mask[maskIndex] + noiseBu.get(maskIndex);
|
||||||
}
|
}
|
||||||
return mask;
|
return mask;
|
||||||
}
|
}
|
||||||
|
|
||||||
public NewArray<byte[]> byteToArray(ByteBuffer buf, int size) {
|
private NewArray<byte[]> byteToArray(ByteBuffer buf, int size) {
|
||||||
NewArray<byte[]> newArray = new NewArray<>();
|
NewArray<byte[]> newArray = new NewArray<>();
|
||||||
newArray.setSize(size);
|
newArray.setSize(size);
|
||||||
byte[] array = new byte[size];
|
byte[] array = new byte[size];
|
||||||
|
@ -258,40 +401,80 @@ public class CipherClient {
|
||||||
return newArray;
|
return newArray;
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus requestExchangeKeys() {
|
private FLClientStatus requestExchangeKeys() {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "=============="));
|
LOGGER.info(Common.addTag("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() +
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
"=============="));
|
||||||
genDHKeyPairs();
|
FLClientStatus status = genDHKeyPairs();
|
||||||
|
if (status == FLClientStatus.FAILED) {
|
||||||
|
LOGGER.severe(Common.addTag("[requestExchangeKeys] the return status is FAILED, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
if (cKey.size() <= 0 || sKey.size() <= 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of cKey or sKey is not valid: <=0."));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
if (cKey.size() < 2) {
|
||||||
|
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of cKey is not valid: < 2, it should be >= 2," +
|
||||||
|
" please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
if (sKey.size() < 2) {
|
||||||
|
LOGGER.severe(Common.addTag("[requestExchangeKeys] the size of sKey is not valid: < 2, it should be >= 2," +
|
||||||
|
" please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||||
byte[] cPK = cKey.get(0);
|
byte[] cPK = cKey.get(0);
|
||||||
byte[] sPK = sKey.get(0);
|
byte[] sPK = sKey.get(0);
|
||||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
|
||||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
|
||||||
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
|
int cpk = RequestExchangeKeys.createCPkVector(fbBuilder, cPK);
|
||||||
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
|
int spk = RequestExchangeKeys.createSPkVector(fbBuilder, sPK);
|
||||||
String dateTime = LocalDateTime.now().toString();
|
|
||||||
|
byte[] indIv = new byte[I_VEC_LEN];
|
||||||
|
byte[] pwIv = new byte[I_VEC_LEN];
|
||||||
|
byte[] thisPwSalt = new byte[SALT_SIZE];
|
||||||
|
SecureRandom secureRandom = Common.getSecureRandom();
|
||||||
|
secureRandom.nextBytes(indIv);
|
||||||
|
secureRandom.nextBytes(pwIv);
|
||||||
|
secureRandom.nextBytes(thisPwSalt);
|
||||||
|
this.individualIv = indIv;
|
||||||
|
this.pwIVec = pwIv;
|
||||||
|
this.pwSalt = thisPwSalt;
|
||||||
|
int indIvFbs = RequestExchangeKeys.createIndIvVector(fbBuilder, indIv);
|
||||||
|
int pwIvFbs = RequestExchangeKeys.createPwIvVector(fbBuilder, pwIv);
|
||||||
|
int pwSaltFbs = RequestExchangeKeys.createPwSaltVector(fbBuilder, thisPwSalt);
|
||||||
|
|
||||||
|
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||||
|
Date date = new Date();
|
||||||
|
long timestamp = date.getTime();
|
||||||
|
String dateTime = String.valueOf(timestamp);
|
||||||
int time = fbBuilder.createString(dateTime);
|
int time = fbBuilder.createString(dateTime);
|
||||||
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time);
|
|
||||||
|
int exchangeKeysRoot = RequestExchangeKeys.createRequestExchangeKeys(fbBuilder, id, cpk, spk, iteration, time
|
||||||
|
, indIvFbs, pwIvFbs, pwSaltFbs);
|
||||||
fbBuilder.finish(exchangeKeysRoot);
|
fbBuilder.finish(exchangeKeysRoot);
|
||||||
byte[] msg = fbBuilder.sizedByteArray();
|
byte[] msg = fbBuilder.sizedByteArray();
|
||||||
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||||
|
flParameter.getDomainName());
|
||||||
try {
|
try {
|
||||||
byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg);
|
byte[] responseData = flCommunication.syncRequest(url + "/exchangeKeys", msg);
|
||||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(responseData)) {
|
||||||
LOGGER.info(Common.addTag("[requestExchangeKeys] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[requestExchangeKeys] the server is not ready now, need wait some time and" +
|
||||||
|
" " +
|
||||||
|
"request again"));
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
nextRequestTime = "";
|
nextRequestTime = "";
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
}
|
}
|
||||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||||
ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer);
|
ResponseExchangeKeys responseExchangeKeys = ResponseExchangeKeys.getRootAsResponseExchangeKeys(buffer);
|
||||||
FLClientStatus status = judgeRequestExchangeKeys(responseExchangeKeys);
|
return judgeRequestExchangeKeys(responseExchangeKeys);
|
||||||
return status;
|
} catch (IOException ex) {
|
||||||
} catch (Exception e) {
|
LOGGER.severe(Common.addTag("[requestExchangeKeys] catch IOException: " + ex.getMessage()));
|
||||||
e.printStackTrace();
|
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
|
private FLClientStatus judgeRequestExchangeKeys(ResponseExchangeKeys bufData) {
|
||||||
retCode = bufData.retcode();
|
retCode = bufData.retcode();
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************"));
|
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestExchangeKeys**************"));
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||||
|
@ -303,7 +486,8 @@ public class CipherClient {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success"));
|
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys success"));
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
case (ResponseCode.OutOfTime):
|
case (ResponseCode.OutOfTime):
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request startFLJob again"));
|
LOGGER.info(Common.addTag("[PairWiseMask] RequestExchangeKeys out of time: need wait and request " +
|
||||||
|
"startFLJob again"));
|
||||||
setNextRequestTime(bufData.nextReqTime());
|
setNextRequestTime(bufData.nextReqTime());
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
case (ResponseCode.RequestError):
|
case (ResponseCode.RequestError):
|
||||||
|
@ -311,39 +495,43 @@ public class CipherClient {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys"));
|
LOGGER.info(Common.addTag("[PairWiseMask] catch RequestError or SystemError in RequestExchangeKeys"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
default:
|
default:
|
||||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys is invalid: " + retCode));
|
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseExchangeKeys " +
|
||||||
|
"is invalid: " + retCode));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus getExchangeKeys() {
|
private FLClientStatus getExchangeKeys() {
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
|
||||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||||
String dateTime = LocalDateTime.now().toString();
|
Date date = new Date();
|
||||||
|
long timestamp = date.getTime();
|
||||||
|
String dateTime = String.valueOf(timestamp);
|
||||||
int time = fbBuilder.createString(dateTime);
|
int time = fbBuilder.createString(dateTime);
|
||||||
int getExchangeKeysRoot = GetExchangeKeys.createGetExchangeKeys(fbBuilder, id, iteration, time);
|
int getExchangeKeysRoot = GetExchangeKeys.createGetExchangeKeys(fbBuilder, id, iteration, time);
|
||||||
fbBuilder.finish(getExchangeKeysRoot);
|
fbBuilder.finish(getExchangeKeysRoot);
|
||||||
byte[] msg = fbBuilder.sizedByteArray();
|
byte[] msg = fbBuilder.sizedByteArray();
|
||||||
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||||
|
flParameter.getDomainName());
|
||||||
try {
|
try {
|
||||||
byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg);
|
byte[] responseData = flCommunication.syncRequest(url + "/getKeys", msg);
|
||||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(responseData)) {
|
||||||
LOGGER.info(Common.addTag("[getExchangeKeys] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[getExchangeKeys] the server is not ready now, need wait some time and " +
|
||||||
|
"request again"));
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
nextRequestTime = "";
|
nextRequestTime = "";
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
}
|
}
|
||||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||||
ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer);
|
ReturnExchangeKeys returnExchangeKeys = ReturnExchangeKeys.getRootAsReturnExchangeKeys(buffer);
|
||||||
FLClientStatus status = judgeGetExchangeKeys(returnExchangeKeys);
|
return judgeGetExchangeKeys(returnExchangeKeys);
|
||||||
return status;
|
} catch (IOException ex) {
|
||||||
} catch (Exception e) {
|
LOGGER.severe(Common.addTag("[getExchangeKeys] catch IOException: " + ex.getMessage()));
|
||||||
e.printStackTrace();
|
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
|
private FLClientStatus judgeGetExchangeKeys(ReturnExchangeKeys bufData) {
|
||||||
retCode = bufData.retcode();
|
retCode = bufData.retcode();
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************"));
|
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetExchangeKeys**************"));
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||||
|
@ -363,17 +551,25 @@ public class CipherClient {
|
||||||
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
|
int sizeCpk = bufData.remotePublickeys(i).cPkLength();
|
||||||
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
|
ByteBuffer bufSpk = bufData.remotePublickeys(i).sPkAsByteBuffer();
|
||||||
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
|
int sizeSpk = bufData.remotePublickeys(i).sPkLength();
|
||||||
|
ByteBuffer bufPwIv = bufData.remotePublickeys(i).pwIvAsByteBuffer();
|
||||||
|
int sizePwIv = bufData.remotePublickeys(i).pwIvLength();
|
||||||
|
ByteBuffer bufPwSalt = bufData.remotePublickeys(i).pwSaltAsByteBuffer();
|
||||||
|
int sizePwSalt = bufData.remotePublickeys(i).pwSaltLength();
|
||||||
publicKey.setCPK(byteToArray(bufCpk, sizeCpk));
|
publicKey.setCPK(byteToArray(bufCpk, sizeCpk));
|
||||||
publicKey.setSPK(byteToArray(bufSpk, sizeSpk));
|
publicKey.setSPK(byteToArray(bufSpk, sizeSpk));
|
||||||
|
publicKey.setPwIv(byteToArray(bufPwIv, sizePwIv));
|
||||||
|
publicKey.setPwSalt(byteToArray(bufPwSalt, sizePwSalt));
|
||||||
clientPublicKeyList.put(srcFlId, publicKey);
|
clientPublicKeyList.put(srcFlId, publicKey);
|
||||||
u1ClientList.add(srcFlId);
|
u1ClientList.add(srcFlId);
|
||||||
}
|
}
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
case (ResponseCode.SucNotReady):
|
case (ResponseCode.SucNotReady):
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetExchangeKeys again!"));
|
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
|
||||||
|
"GetExchangeKeys again!"));
|
||||||
return FLClientStatus.WAIT;
|
return FLClientStatus.WAIT;
|
||||||
case (ResponseCode.OutOfTime):
|
case (ResponseCode.OutOfTime):
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request startFLJob again"));
|
LOGGER.info(Common.addTag("[PairWiseMask] GetExchangeKeys out of time: need wait and request " +
|
||||||
|
"startFLJob again"));
|
||||||
setNextRequestTime(bufData.nextReqTime());
|
setNextRequestTime(bufData.nextReqTime());
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
case (ResponseCode.RequestError):
|
case (ResponseCode.RequestError):
|
||||||
|
@ -381,32 +577,48 @@ public class CipherClient {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys"));
|
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetExchangeKeys"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
default:
|
default:
|
||||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is invalid: " + retCode));
|
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnExchangeKeys is" +
|
||||||
|
" invalid: " + retCode));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus requestShareSecrets() throws Exception {
|
private FLClientStatus requestShareSecrets() {
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
FLClientStatus status = genIndividualSecret();
|
||||||
genIndividualSecret();
|
if (status == FLClientStatus.FAILED) {
|
||||||
genEncryptExchangedKeys();
|
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from genIndividualSecret" +
|
||||||
encryptShares();
|
"(), please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
status = genEncryptExchangedKeys();
|
||||||
|
if (status == FLClientStatus.FAILED) {
|
||||||
|
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from " +
|
||||||
|
"genEncryptExchangedKeys(), please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
status = encryptShares();
|
||||||
|
if (status == FLClientStatus.FAILED) {
|
||||||
|
LOGGER.severe(Common.addTag("[requestShareSecrets] the returned status is FAILED from encryptShares(), " +
|
||||||
|
"please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||||
String dateTime = LocalDateTime.now().toString();
|
Date date = new Date();
|
||||||
|
long timestamp = date.getTime();
|
||||||
|
String dateTime = String.valueOf(timestamp);
|
||||||
int time = fbBuilder.createString(dateTime);
|
int time = fbBuilder.createString(dateTime);
|
||||||
int clientShareSize = clientShareList.size();
|
int clientShareSize = clientShareList.size();
|
||||||
if (clientShareSize <= 0) {
|
if (clientShareSize <= 0) {
|
||||||
LOGGER.warning(Common.addTag("[PairWiseMask] encrypt shares is not ready now!"));
|
LOGGER.warning(Common.addTag("[PairWiseMask] encrypt shares is not ready now!"));
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
FLClientStatus status = requestShareSecrets();
|
return requestShareSecrets();
|
||||||
return status;
|
|
||||||
} else {
|
} else {
|
||||||
int[] add = new int[clientShareSize];
|
int[] add = new int[clientShareSize];
|
||||||
for (int i = 0; i < clientShareSize; i++) {
|
for (int i = 0; i < clientShareSize; i++) {
|
||||||
int flID = fbBuilder.createString(clientShareList.get(i).getFlID());
|
int flID = fbBuilder.createString(clientShareList.get(i).getFlID());
|
||||||
int shareSecretFbs = ClientShare.createShareVector(fbBuilder, clientShareList.get(i).getShare().getArray());
|
int shareSecretFbs = ClientShare.createShareVector(fbBuilder,
|
||||||
|
clientShareList.get(i).getShare().getArray());
|
||||||
ClientShare.startClientShare(fbBuilder);
|
ClientShare.startClientShare(fbBuilder);
|
||||||
ClientShare.addFlId(fbBuilder, flID);
|
ClientShare.addFlId(fbBuilder, flID);
|
||||||
ClientShare.addShare(fbBuilder, shareSecretFbs);
|
ClientShare.addShare(fbBuilder, shareSecretFbs);
|
||||||
|
@ -414,29 +626,33 @@ public class CipherClient {
|
||||||
add[i] = clientShareRoot;
|
add[i] = clientShareRoot;
|
||||||
}
|
}
|
||||||
int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add);
|
int encryptedSharesFbs = RequestShareSecrets.createEncryptedSharesVector(fbBuilder, add);
|
||||||
int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id, encryptedSharesFbs, iteration, time);
|
int requestShareSecretsRoot = RequestShareSecrets.createRequestShareSecrets(fbBuilder, id,
|
||||||
|
encryptedSharesFbs, iteration, time);
|
||||||
fbBuilder.finish(requestShareSecretsRoot);
|
fbBuilder.finish(requestShareSecretsRoot);
|
||||||
byte[] msg = fbBuilder.sizedByteArray();
|
byte[] msg = fbBuilder.sizedByteArray();
|
||||||
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||||
|
flParameter.getDomainName());
|
||||||
try {
|
try {
|
||||||
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
|
byte[] responseData = flCommunication.syncRequest(url + "/shareSecrets", msg);
|
||||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(responseData)) {
|
||||||
LOGGER.info(Common.addTag("[requestShareSecrets] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[requestShareSecrets] the server is not ready now, need wait some time" +
|
||||||
|
" " +
|
||||||
|
"and request again"));
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
nextRequestTime = "";
|
nextRequestTime = "";
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
}
|
}
|
||||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||||
ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer);
|
ResponseShareSecrets responseShareSecrets = ResponseShareSecrets.getRootAsResponseShareSecrets(buffer);
|
||||||
FLClientStatus status = judgeRequestShareSecrets(responseShareSecrets);
|
return judgeRequestShareSecrets(responseShareSecrets);
|
||||||
return status;
|
} catch (IOException ex) {
|
||||||
} catch (Exception e) {
|
LOGGER.severe(Common.addTag("[requestShareSecrets] catch IOException: " + ex.getMessage()));
|
||||||
e.printStackTrace();
|
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
|
private FLClientStatus judgeRequestShareSecrets(ResponseShareSecrets bufData) {
|
||||||
retCode = bufData.retcode();
|
retCode = bufData.retcode();
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************"));
|
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of RequestShareSecrets**************"));
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||||
|
@ -448,7 +664,8 @@ public class CipherClient {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success"));
|
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets success"));
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
case (ResponseCode.OutOfTime):
|
case (ResponseCode.OutOfTime):
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request startFLJob again"));
|
LOGGER.info(Common.addTag("[PairWiseMask] RequestShareSecrets out of time: need wait and request " +
|
||||||
|
"startFLJob again"));
|
||||||
setNextRequestTime(bufData.nextReqTime());
|
setNextRequestTime(bufData.nextReqTime());
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
case (ResponseCode.RequestError):
|
case (ResponseCode.RequestError):
|
||||||
|
@ -456,39 +673,43 @@ public class CipherClient {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets"));
|
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in RequestShareSecrets"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
default:
|
default:
|
||||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets is invalid: " + retCode));
|
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ResponseShareSecrets " +
|
||||||
|
"is invalid: " + retCode));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus getShareSecrets() {
|
private FLClientStatus getShareSecrets() {
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
|
||||||
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
FlatBufferBuilder fbBuilder = new FlatBufferBuilder();
|
||||||
int id = fbBuilder.createString(localFLParameter.getFlID());
|
int id = fbBuilder.createString(localFLParameter.getFlID());
|
||||||
String dateTime = LocalDateTime.now().toString();
|
Date date = new Date();
|
||||||
|
long timestamp = date.getTime();
|
||||||
|
String dateTime = String.valueOf(timestamp);
|
||||||
int time = fbBuilder.createString(dateTime);
|
int time = fbBuilder.createString(dateTime);
|
||||||
int getShareSecrets = GetShareSecrets.createGetShareSecrets(fbBuilder, id, iteration, time);
|
int getShareSecrets = GetShareSecrets.createGetShareSecrets(fbBuilder, id, iteration, time);
|
||||||
fbBuilder.finish(getShareSecrets);
|
fbBuilder.finish(getShareSecrets);
|
||||||
byte[] msg = fbBuilder.sizedByteArray();
|
byte[] msg = fbBuilder.sizedByteArray();
|
||||||
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||||
|
flParameter.getDomainName());
|
||||||
try {
|
try {
|
||||||
byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg);
|
byte[] responseData = flCommunication.syncRequest(url + "/getSecrets", msg);
|
||||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(responseData)) {
|
||||||
LOGGER.info(Common.addTag("[getShareSecrets] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[getShareSecrets] the server is not ready now, need wait some time and " +
|
||||||
|
"request again"));
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
nextRequestTime = "";
|
nextRequestTime = "";
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
}
|
}
|
||||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||||
ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer);
|
ReturnShareSecrets returnShareSecrets = ReturnShareSecrets.getRootAsReturnShareSecrets(buffer);
|
||||||
FLClientStatus status = judgeGetShareSecrets(returnShareSecrets);
|
return judgeGetShareSecrets(returnShareSecrets);
|
||||||
return status;
|
} catch (IOException ex) {
|
||||||
} catch (Exception e) {
|
LOGGER.severe(Common.addTag("[getShareSecrets] catch IOException: " + ex.getMessage()));
|
||||||
e.printStackTrace();
|
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
|
private FLClientStatus judgeGetShareSecrets(ReturnShareSecrets bufData) {
|
||||||
retCode = bufData.retcode();
|
retCode = bufData.retcode();
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************"));
|
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of GetShareSecrets**************"));
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||||
|
@ -503,20 +724,26 @@ public class CipherClient {
|
||||||
int length = bufData.encryptedSharesLength();
|
int length = bufData.encryptedSharesLength();
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
EncryptShare shareSecret = new EncryptShare();
|
EncryptShare shareSecret = new EncryptShare();
|
||||||
shareSecret.setFlID(bufData.encryptedShares(i).flId());
|
ClientShare clientShare = bufData.encryptedShares(i);
|
||||||
ByteBuffer bufShare = bufData.encryptedShares(i).shareAsByteBuffer();
|
if (clientShare == null) {
|
||||||
int sizeShare = bufData.encryptedShares(i).shareLength();
|
LOGGER.severe(Common.addTag("[PairWiseMask] the clientShare returned from server is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
shareSecret.setFlID(clientShare.flId());
|
||||||
|
ByteBuffer bufShare = clientShare.shareAsByteBuffer();
|
||||||
|
int sizeShare = clientShare.shareLength();
|
||||||
shareSecret.setShare(byteToArray(bufShare, sizeShare));
|
shareSecret.setShare(byteToArray(bufShare, sizeShare));
|
||||||
returnShareList.add(shareSecret);
|
returnShareList.add(shareSecret);
|
||||||
u2UClientList.add(bufData.encryptedShares(i).flId());
|
u2UClientList.add(clientShare.flId());
|
||||||
}
|
}
|
||||||
|
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
case (ResponseCode.SucNotReady):
|
case (ResponseCode.SucNotReady):
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetShareSecrets again!"));
|
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
|
||||||
|
"GetShareSecrets again!"));
|
||||||
return FLClientStatus.WAIT;
|
return FLClientStatus.WAIT;
|
||||||
case (ResponseCode.OutOfTime):
|
case (ResponseCode.OutOfTime):
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request startFLJob again"));
|
LOGGER.info(Common.addTag("[PairWiseMask] GetShareSecrets out of time: need wait and request " +
|
||||||
|
"startFLJob again"));
|
||||||
setNextRequestTime(bufData.nextReqTime());
|
setNextRequestTime(bufData.nextReqTime());
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
case (ResponseCode.RequestError):
|
case (ResponseCode.RequestError):
|
||||||
|
@ -524,15 +751,22 @@ public class CipherClient {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets"));
|
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetShareSecrets"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
default:
|
default:
|
||||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is invalid: " + retCode));
|
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnShareSecrets is" +
|
||||||
|
" invalid: " + retCode));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* exchangeKeys round of secure aggregation
|
||||||
|
*
|
||||||
|
* @return round execution result
|
||||||
|
*/
|
||||||
public FLClientStatus exchangeKeys() {
|
public FLClientStatus exchangeKeys() {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys ======================"));
|
LOGGER.info(Common.addTag("[PairWiseMask] ==================== round0: RequestExchangeKeys+GetExchangeKeys " +
|
||||||
FLClientStatus curStatus;
|
"======================"));
|
||||||
// RequestExchangeKeys
|
// RequestExchangeKeys
|
||||||
|
FLClientStatus curStatus;
|
||||||
curStatus = requestExchangeKeys();
|
curStatus = requestExchangeKeys();
|
||||||
while (curStatus == FLClientStatus.WAIT) {
|
while (curStatus == FLClientStatus.WAIT) {
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
|
@ -551,8 +785,14 @@ public class CipherClient {
|
||||||
return curStatus;
|
return curStatus;
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus shareSecrets() throws Exception {
|
/**
|
||||||
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets ======================")));
|
* shareSecrets round of secure aggregation
|
||||||
|
*
|
||||||
|
* @return round execution result
|
||||||
|
*/
|
||||||
|
public FLClientStatus shareSecrets() {
|
||||||
|
LOGGER.info(Common.addTag(("[PairWiseMask] ==================== round1: RequestShareSecrets+GetShareSecrets " +
|
||||||
|
"======================")));
|
||||||
FLClientStatus curStatus;
|
FLClientStatus curStatus;
|
||||||
// RequestShareSecrets
|
// RequestShareSecrets
|
||||||
curStatus = requestShareSecrets();
|
curStatus = requestShareSecrets();
|
||||||
|
@ -573,14 +813,22 @@ public class CipherClient {
|
||||||
return curStatus;
|
return curStatus;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* reconstructSecrets round of secure aggregation
|
||||||
|
*
|
||||||
|
* @return round execution result
|
||||||
|
*/
|
||||||
public FLClientStatus reconstructSecrets() {
|
public FLClientStatus reconstructSecrets() {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret ========================"));
|
LOGGER.info(Common.addTag("[PairWiseMask] =================== round3: GetClientList+SendReconstructSecret " +
|
||||||
|
"========================"));
|
||||||
FLClientStatus curStatus;
|
FLClientStatus curStatus;
|
||||||
// GetClientList
|
// GetClientList
|
||||||
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
|
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList,
|
||||||
|
cUVKeys);
|
||||||
while (curStatus == FLClientStatus.WAIT) {
|
while (curStatus == FLClientStatus.WAIT) {
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList, cUVKeys);
|
curStatus = clientListReq.getClientList(iteration, u3ClientList, decryptShareSecretsList, returnShareList
|
||||||
|
, cUVKeys);
|
||||||
}
|
}
|
||||||
if (curStatus == FLClientStatus.RESTART) {
|
if (curStatus == FLClientStatus.RESTART) {
|
||||||
nextRequestTime = clientListReq.getNextRequestTime();
|
nextRequestTime = clientListReq.getNextRequestTime();
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,10 +13,17 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
import org.bouncycastle.crypto.BlockCipher;
|
||||||
|
import org.bouncycastle.crypto.engines.AESEngine;
|
||||||
|
import org.bouncycastle.crypto.prng.SP800SecureRandomBuilder;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import java.security.SecureRandom;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
|
@ -26,28 +33,94 @@ import java.util.logging.Logger;
|
||||||
import java.util.regex.Matcher;
|
import java.util.regex.Matcher;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define basic global methods used in federated learning task.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class Common {
|
public class Common {
|
||||||
|
/**
|
||||||
|
* Global logger title.
|
||||||
|
*/
|
||||||
public static final String LOG_TITLE = "<FLClient> ";
|
public static final String LOG_TITLE = "<FLClient> ";
|
||||||
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
|
|
||||||
private static List<String> flNameTrustList = new ArrayList<>(Arrays.asList("lenet", "albert"));
|
|
||||||
|
|
||||||
public static String generateUrl(boolean useHttps, boolean useElb, String ip, int port, int serverNum) {
|
/**
|
||||||
if (useHttps) {
|
* The list of trust flName.
|
||||||
ip = "https://" + ip + ":";
|
*/
|
||||||
} else {
|
public static final List<String> FL_NAME_TRUST_LIST = new ArrayList<>(Arrays.asList("lenet", "albert"));
|
||||||
ip = "http://" + ip + ":";
|
|
||||||
|
/**
|
||||||
|
* The list of trust ssl protocol.
|
||||||
|
*/
|
||||||
|
public static final List<String> SSL_PROTOCOL_TRUST_LIST = new ArrayList<>(Arrays.asList("TLSv1.3", "TLSv1.2"));
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The tag when server is in safe mode.
|
||||||
|
*/
|
||||||
|
public static final String SAFE_MOD = "The cluster is in safemode.";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The tag when server is not ready.
|
||||||
|
*/
|
||||||
|
public static final String JOB_NOT_AVAILABLE = "The server's training job is disabled or finished.";
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(Common.class.toString());
|
||||||
|
private static SecureRandom secureRandom;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate the URL for device-sever interaction
|
||||||
|
*
|
||||||
|
* @param ifUseElb whether a client randomly sends a request to a server address within a specified range.
|
||||||
|
* @param serverNum number of servers that can send requests.
|
||||||
|
* @param domainName the URL for device-sever interaction set by user.
|
||||||
|
* @return the URL for device-sever interaction.
|
||||||
|
*/
|
||||||
|
public static String generateUrl(boolean ifUseElb, int serverNum, String domainName) {
|
||||||
|
if (serverNum <= 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[generateUrl] the input argument <serverNum> is not valid: <= 0, it should " +
|
||||||
|
"be > 0, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
String url;
|
String url;
|
||||||
if (useElb) {
|
if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
|
||||||
|
LOGGER.severe(Common.addTag("[generateUrl] the input argument <domainName> is null or not valid, it " +
|
||||||
|
"should be like as https://...... or http://...... , please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
if (ifUseElb) {
|
||||||
|
if (domainName.split("//")[1].split(":").length < 2) {
|
||||||
|
LOGGER.severe(Common.addTag("[generateUrl] the format of <domainName> is not valid, it should be like" +
|
||||||
|
" as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
String ip = domainName.split("//")[1].split(":")[0];
|
||||||
|
int port = Integer.parseInt(domainName.split("//")[1].split(":")[1]);
|
||||||
|
if (!Common.checkIP(ip)) {
|
||||||
|
LOGGER.severe(Common.addTag("[generateUrl] the <ip> split from domainName is not valid, domainName " +
|
||||||
|
"should be like as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, " +
|
||||||
|
"please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
if (!Common.checkPort(port)) {
|
||||||
|
LOGGER.severe(Common.addTag("[generateUrl] the <port> split from domainName is not valid, domainName " +
|
||||||
|
"should be like as https://127.0.0.1:6666 or http://127.0.0.1:6666 when set useElb to true, " +
|
||||||
|
"please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
String tag = domainName.split("//")[0] + "//";
|
||||||
Random rand = new Random();
|
Random rand = new Random();
|
||||||
int randomNum = rand.nextInt(100000) % serverNum + port;
|
int randomNum = rand.nextInt(100000) % serverNum + port;
|
||||||
url = ip + String.valueOf(randomNum);
|
url = tag + ip + ":" + String.valueOf(randomNum);
|
||||||
} else {
|
} else {
|
||||||
url = ip + String.valueOf(port);
|
url = domainName;
|
||||||
}
|
}
|
||||||
return url;
|
return url;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Store weight name of classifier to a list.
|
||||||
|
*
|
||||||
|
* @param classifierWeightName the list to store weight name of classifier.
|
||||||
|
*/
|
||||||
public static void setClassifierWeightName(List<String> classifierWeightName) {
|
public static void setClassifierWeightName(List<String> classifierWeightName) {
|
||||||
classifierWeightName.add("albert.pooler.weight");
|
classifierWeightName.add("albert.pooler.weight");
|
||||||
classifierWeightName.add("albert.pooler.bias");
|
classifierWeightName.add("albert.pooler.bias");
|
||||||
|
@ -56,6 +129,11 @@ public class Common {
|
||||||
LOGGER.info(addTag("classifierWeightName size: " + classifierWeightName.size()));
|
LOGGER.info(addTag("classifierWeightName size: " + classifierWeightName.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Store weight name of albert network to a list.
|
||||||
|
*
|
||||||
|
* @param albertWeightName the list to store weight name of albert network.
|
||||||
|
*/
|
||||||
public static void setAlbertWeightName(List<String> albertWeightName) {
|
public static void setAlbertWeightName(List<String> albertWeightName) {
|
||||||
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.weight");
|
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.weight");
|
||||||
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.bias");
|
albertWeightName.add("albert.encoder.embedding_hidden_mapping_in.bias");
|
||||||
|
@ -78,32 +156,67 @@ public class Common {
|
||||||
LOGGER.info(addTag("albertWeightName size: " + albertWeightName.size()));
|
LOGGER.info(addTag("albertWeightName size: " + albertWeightName.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check whether the flName set by user is in the trust list.
|
||||||
|
*
|
||||||
|
* @param flName the model name set by user.
|
||||||
|
* @return boolean value, true indicates the flName set by user is valid, false indicates the flName set by user
|
||||||
|
* is not valid.
|
||||||
|
*/
|
||||||
public static boolean checkFLName(String flName) {
|
public static boolean checkFLName(String flName) {
|
||||||
return (flNameTrustList.contains(flName));
|
return (FL_NAME_TRUST_LIST.contains(flName));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check whether the sslProtocol set by user is in the trust list.
|
||||||
|
*
|
||||||
|
* @param sslProtocol the ssl protocol set by user.
|
||||||
|
* @return boolean value, true indicates the sslProtocol set by user is valid, false indicates the sslProtocol
|
||||||
|
* set by user is not valid.
|
||||||
|
*/
|
||||||
|
public static boolean checkSSLProtocol(String sslProtocol) {
|
||||||
|
return (SSL_PROTOCOL_TRUST_LIST.contains(sslProtocol));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The program waits for the specified time and then to continue.
|
||||||
|
*
|
||||||
|
* @param millis the waiting time (ms).
|
||||||
|
*/
|
||||||
public static void sleep(long millis) {
|
public static void sleep(long millis) {
|
||||||
try {
|
try {
|
||||||
Thread.sleep(millis); //1000 milliseconds is one second.
|
Thread.sleep(millis); // 1000 milliseconds is one second.
|
||||||
} catch (InterruptedException ex) {
|
} catch (InterruptedException ex) {
|
||||||
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
|
LOGGER.severe(addTag("[sleep] catch InterruptedException: " + ex.getMessage()));
|
||||||
Thread.currentThread().interrupt();
|
Thread.currentThread().interrupt();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the waiting time for repeated requests.
|
||||||
|
*
|
||||||
|
* @param nextRequestTime the timestamp return from server.
|
||||||
|
* @return the waiting time for repeated requests.
|
||||||
|
*/
|
||||||
public static long getWaitTime(String nextRequestTime) {
|
public static long getWaitTime(String nextRequestTime) {
|
||||||
|
|
||||||
Date date = new Date();
|
Date date = new Date();
|
||||||
long currentTime = date.getTime();
|
long currentTime = date.getTime();
|
||||||
long waitTime = 0;
|
long waitTime = 0L;
|
||||||
if (!("").equals(nextRequestTime)) {
|
if (!(nextRequestTime == null || nextRequestTime.isEmpty())) {
|
||||||
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
|
waitTime = Math.max(0, Long.valueOf(nextRequestTime) - currentTime);
|
||||||
}
|
}
|
||||||
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " + currentTime));
|
LOGGER.info(addTag("[getWaitTime] next request time stamp: " + nextRequestTime + " current time stamp: " +
|
||||||
|
currentTime));
|
||||||
LOGGER.info(addTag("[getWaitTime] waitTime: " + waitTime));
|
LOGGER.info(addTag("[getWaitTime] waitTime: " + waitTime));
|
||||||
return waitTime;
|
return waitTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get start time.
|
||||||
|
*
|
||||||
|
* @param tag the tag added to the logger.
|
||||||
|
* @return start time.
|
||||||
|
*/
|
||||||
public static long startTime(String tag) {
|
public static long startTime(String tag) {
|
||||||
Date startDate = new Date();
|
Date startDate = new Date();
|
||||||
long startTime = startDate.getTime();
|
long startTime = startDate.getTime();
|
||||||
|
@ -111,6 +224,12 @@ public class Common {
|
||||||
return startTime;
|
return startTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get end time.
|
||||||
|
*
|
||||||
|
* @param start the start time.
|
||||||
|
* @param tag the tag added to the logger.
|
||||||
|
*/
|
||||||
public static void endTime(long start, String tag) {
|
public static void endTime(long start, String tag) {
|
||||||
Date endDate = new Date();
|
Date endDate = new Date();
|
||||||
long endTime = endDate.getTime();
|
long endTime = endDate.getTime();
|
||||||
|
@ -118,53 +237,182 @@ public class Common {
|
||||||
LOGGER.info(addTag("[interval time] <" + tag + "> interval time(ms): " + (endTime - start)));
|
LOGGER.info(addTag("[interval time] <" + tag + "> interval time(ms): " + (endTime - start)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add specified tag to the message.
|
||||||
|
*
|
||||||
|
* @param message the message need to add tag.
|
||||||
|
* @return the message after adding tag.
|
||||||
|
*/
|
||||||
public static String addTag(String message) {
|
public static String addTag(String message) {
|
||||||
return LOG_TITLE + message;
|
return LOG_TITLE + message;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static boolean isSafeMod(byte[] message, String safeModTag) {
|
/**
|
||||||
return (new String(message)).contains(safeModTag);
|
* Check whether the server is ready based on the message returned by the server.
|
||||||
|
*
|
||||||
|
* @param message the message returned by the server..
|
||||||
|
* @return boolean value, true indicates the server is ready, false indicates the server is not ready.
|
||||||
|
*/
|
||||||
|
public static boolean isSeverReady(byte[] message) {
|
||||||
|
if (message == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[isSeverReady] the input argument <message> is null, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
String messageStr = new String(message);
|
||||||
|
if (messageStr.contains(SAFE_MOD)) {
|
||||||
|
LOGGER.info(Common.addTag("[isSeverReady] " + SAFE_MOD + ", need wait some time and request again"));
|
||||||
|
return false;
|
||||||
|
} else if (messageStr.contains(JOB_NOT_AVAILABLE)) {
|
||||||
|
LOGGER.info(Common.addTag("[isSeverReady] " + JOB_NOT_AVAILABLE + ", need wait some time and request " +
|
||||||
|
"again"));
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static String getRealPath (String path) {
|
/**
|
||||||
LOGGER.info(addTag("[original path] " + path));
|
* Convert a user-set path to a standard path.
|
||||||
|
*
|
||||||
|
* @param path the user-set path.
|
||||||
|
* @return the standard path.
|
||||||
|
*/
|
||||||
|
public static String getRealPath(String path) {
|
||||||
|
if (path == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[getRealPath] the input argument <path> is null, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
LOGGER.info(addTag("[getRealPath] original path: " + path));
|
||||||
String[] paths = path.split(",");
|
String[] paths = path.split(",");
|
||||||
for (int i = 0; i < paths.length; i++) {
|
for (int i = 0; i < paths.length; i++) {
|
||||||
LOGGER.info(addTag("[original path " + i + "] " + paths[i]));
|
if (paths[i] == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[getRealPath] the paths[i] is null, please check"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
LOGGER.info(addTag("[getRealPath] original path " + i + ": " + paths[i]));
|
||||||
File file = new File(paths[i]);
|
File file = new File(paths[i]);
|
||||||
try {
|
try {
|
||||||
paths[i] = file.getCanonicalPath();
|
paths[i] = file.getCanonicalPath();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
LOGGER.severe(addTag("[checkPath] catch IOException in file.getCanonicalPath(): " + e.getMessage()));
|
LOGGER.severe(addTag("[getRealPath] catch IOException in file.getCanonicalPath(): " + e.getMessage()));
|
||||||
throw new RuntimeException();
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
path = String.join(",", Arrays.asList(paths));
|
String realPath = String.join(",", Arrays.asList(paths));
|
||||||
LOGGER.info(addTag("[real path] " + path));
|
LOGGER.info(addTag("[getRealPath] real path: " + realPath));
|
||||||
return path;
|
return realPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check whether the path set by user exists.
|
||||||
|
*
|
||||||
|
* @param path the path set by user.
|
||||||
|
* @return boolean value, true indicates the path is exist, false indicates the path is not exist
|
||||||
|
*/
|
||||||
public static boolean checkPath(String path) {
|
public static boolean checkPath(String path) {
|
||||||
boolean tag = true;
|
if (path == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[checkPath] the input argument <path> is null, please check!"));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
String[] paths = path.split(",");
|
String[] paths = path.split(",");
|
||||||
for (int i = 0; i < paths.length; i++) {
|
for (int i = 0; i < paths.length; i++) {
|
||||||
|
if (paths[i] == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[checkPath] the paths[i] is null, please check"));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
LOGGER.info(addTag("[check path " + i + "] " + paths[i]));
|
LOGGER.info(addTag("[check path " + i + "] " + paths[i]));
|
||||||
File file = new File(paths[i]);
|
File file = new File(paths[i]);
|
||||||
if (!file.exists()) {
|
if (!file.exists()) {
|
||||||
tag = false;
|
LOGGER.severe(Common.addTag("[checkPath] the path is not exist, please check"));
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return tag;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check whether the ip set by user is valid.
|
||||||
|
*
|
||||||
|
* @param ip the ip set by user.
|
||||||
|
* @return boolean value, true indicates the ip is valid, false indicates the ip is not valid.
|
||||||
|
*/
|
||||||
public static boolean checkIP(String ip) {
|
public static boolean checkIP(String ip) {
|
||||||
String regex = "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.](25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.](25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])";
|
if (ip == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[checkIP] the input argument <ip> is null, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
String regex = "(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])[.]" +
|
||||||
|
"(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.]" +
|
||||||
|
"(25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[0-9])[.]" +
|
||||||
|
"(25[0-4]|2[0-4][0-9]|1[0-9][0-9]|[1-9][0-9]|[1-9])";
|
||||||
Pattern pattern = Pattern.compile(regex);
|
Pattern pattern = Pattern.compile(regex);
|
||||||
Matcher matcher = pattern.matcher(ip);
|
Matcher matcher = pattern.matcher(ip);
|
||||||
return matcher.matches();
|
return matcher.matches();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check whether the port set by user is valid.
|
||||||
|
*
|
||||||
|
* @param port the port set by user.
|
||||||
|
* @return boolean value, true indicates the port is valid, false indicates the port is not valid.
|
||||||
|
*/
|
||||||
public static boolean checkPort(int port) {
|
public static boolean checkPort(int port) {
|
||||||
return port > 0 && port <= 65535;
|
return port > 0 && port <= 65535;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain secure random.
|
||||||
|
*
|
||||||
|
* @return the secure random.
|
||||||
|
*/
|
||||||
|
public static SecureRandom getSecureRandom() {
|
||||||
|
if (secureRandom == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[setSecureRandom] the parameter secureRandom is null, please set it before " +
|
||||||
|
"use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
return secureRandom;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the secure random to parameter secureRandom of the class Common.
|
||||||
|
*
|
||||||
|
* @param secureRandom the secure random.
|
||||||
|
*/
|
||||||
|
public static void setSecureRandom(SecureRandom secureRandom) {
|
||||||
|
if (secureRandom == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[setSecureRandom] the input parameter secureRandom is null, please check"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
Common.secureRandom = secureRandom;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain fast secure random.
|
||||||
|
*
|
||||||
|
* @return the fast secure random.
|
||||||
|
*/
|
||||||
|
public static SecureRandom getFastSecureRandom() {
|
||||||
|
try {
|
||||||
|
LOGGER.info(Common.addTag("[getFastSecureRandom] start create fastSecureRandom"));
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
|
SecureRandom blockingRandom = SecureRandom.getInstanceStrong();
|
||||||
|
boolean ifPredictionResistant = true;
|
||||||
|
BlockCipher cipher = new AESEngine();
|
||||||
|
int cipherLen = 256;
|
||||||
|
int entropyBitsRequired = 384;
|
||||||
|
byte[] nonce = null;
|
||||||
|
boolean ifForceReseed = false;
|
||||||
|
SecureRandom fastRandom = new SP800SecureRandomBuilder(blockingRandom, ifPredictionResistant)
|
||||||
|
.setEntropyBitsRequired(entropyBitsRequired)
|
||||||
|
.buildCTR(cipher, cipherLen, nonce, ifForceReseed);
|
||||||
|
fastRandom.nextInt();
|
||||||
|
LOGGER.info(Common.addTag("[getFastSecureRandom] finish create fastSecureRandom"));
|
||||||
|
LOGGER.info(Common.addTag("[getFastSecureRandom] cost time: " + (System.currentTimeMillis() - start)));
|
||||||
|
return fastRandom;
|
||||||
|
} catch (NoSuchAlgorithmException e) {
|
||||||
|
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException: " + e.getMessage()));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,13 +13,17 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The early stop mod.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public enum EarlyStopMod {
|
public enum EarlyStopMod {
|
||||||
LOSS_DIFF,
|
LOSS_DIFF,
|
||||||
LOSS_ABS,
|
LOSS_ABS,
|
||||||
WEIGHT_DIFF,
|
WEIGHT_DIFF,
|
||||||
NOT_EARLY_STOP
|
NOT_EARLY_STOP
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,12 +13,16 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Security encryption level.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public enum EncryptLevel {
|
public enum EncryptLevel {
|
||||||
PW_ENCRYPT,
|
PW_ENCRYPT,
|
||||||
DP_ENCRYPT,
|
DP_ENCRYPT,
|
||||||
NOT_ENCRYPT
|
NOT_ENCRYPT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,21 +1,26 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
* You may obtain a copy of the License at
|
* You may obtain a copy of the License at
|
||||||
*
|
*
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The status code of federated learning.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public enum FLClientStatus {
|
public enum FLClientStatus {
|
||||||
SUCCESS,
|
SUCCESS,
|
||||||
FAILED,
|
FAILED,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.FLParameter.TIME_OUT;
|
||||||
|
|
||||||
import okhttp3.Call;
|
import okhttp3.Call;
|
||||||
import okhttp3.Callback;
|
import okhttp3.Callback;
|
||||||
import okhttp3.MediaType;
|
import okhttp3.MediaType;
|
||||||
|
@ -24,12 +26,6 @@ import okhttp3.Request;
|
||||||
import okhttp3.RequestBody;
|
import okhttp3.RequestBody;
|
||||||
import okhttp3.Response;
|
import okhttp3.Response;
|
||||||
|
|
||||||
import javax.net.ssl.HostnameVerifier;
|
|
||||||
import javax.net.ssl.SSLContext;
|
|
||||||
import javax.net.ssl.SSLSession;
|
|
||||||
import javax.net.ssl.SSLSocketFactory;
|
|
||||||
import javax.net.ssl.TrustManager;
|
|
||||||
import javax.net.ssl.X509TrustManager;
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.security.KeyManagementException;
|
import java.security.KeyManagementException;
|
||||||
import java.security.NoSuchAlgorithmException;
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
@ -39,34 +35,40 @@ import java.util.concurrent.TimeUnit;
|
||||||
import java.util.concurrent.TimeoutException;
|
import java.util.concurrent.TimeoutException;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import static com.mindspore.flclient.FLParameter.TIME_OUT;
|
import javax.net.ssl.HostnameVerifier;
|
||||||
|
import javax.net.ssl.SSLContext;
|
||||||
|
import javax.net.ssl.SSLSession;
|
||||||
|
import javax.net.ssl.SSLSocketFactory;
|
||||||
|
import javax.net.ssl.TrustManager;
|
||||||
|
import javax.net.ssl.X509TrustManager;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define the communication interface.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class FLCommunication implements IFLCommunication {
|
public class FLCommunication implements IFLCommunication {
|
||||||
private static int timeOut;
|
private static int timeOut;
|
||||||
private static boolean ssl = false;
|
private static boolean ifCertificateVerify = false;
|
||||||
private static String env;
|
|
||||||
private static SSLSocketFactory sslSocketFactory;
|
|
||||||
private static X509TrustManager x509TrustManager;
|
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
|
||||||
private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8");
|
private static final MediaType MEDIA_TYPE_JSON = MediaType.parse("applicatiom/json;charset=utf-8");
|
||||||
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(FLCommunication.class.toString());
|
||||||
private OkHttpClient client;
|
|
||||||
|
|
||||||
private static volatile FLCommunication communication;
|
private static volatile FLCommunication communication;
|
||||||
|
|
||||||
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
|
private OkHttpClient client;
|
||||||
|
|
||||||
private FLCommunication() {
|
private FLCommunication() {
|
||||||
if (flParameter.getTimeOut() != 0) {
|
if (flParameter.getTimeOut() != 0) {
|
||||||
timeOut = flParameter.getTimeOut();
|
timeOut = flParameter.getTimeOut();
|
||||||
} else {
|
} else {
|
||||||
timeOut = TIME_OUT;
|
timeOut = TIME_OUT;
|
||||||
}
|
}
|
||||||
ssl = flParameter.isUseSSL();
|
ifCertificateVerify = flParameter.isUseSSL();
|
||||||
client = getOkHttpClient();
|
client = getOkHttpClient();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static OkHttpClient getOkHttpClient() {
|
private static OkHttpClient getOkHttpClient() {
|
||||||
X509TrustManager trustManager = new X509TrustManager() {
|
X509TrustManager trustManager = new X509TrustManager() {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public X509Certificate[] getAcceptedIssuers() {
|
public X509Certificate[] getAcceptedIssuers() {
|
||||||
return new X509Certificate[]{};
|
return new X509Certificate[]{};
|
||||||
|
@ -89,14 +91,15 @@ public class FLCommunication implements IFLCommunication {
|
||||||
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
|
builder.connectTimeout(timeOut, TimeUnit.SECONDS);
|
||||||
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
|
builder.writeTimeout(timeOut, TimeUnit.SECONDS);
|
||||||
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
|
builder.readTimeout(3 * timeOut, TimeUnit.SECONDS);
|
||||||
if (ssl) {
|
if (ifCertificateVerify) {
|
||||||
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(), SSLSocketFactoryTools.getInstance().getmTrustManager());
|
builder.sslSocketFactory(SSLSocketFactoryTools.getInstance().getmSslSocketFactory(),
|
||||||
|
SSLSocketFactoryTools.getInstance().getmTrustManager());
|
||||||
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
|
builder.hostnameVerifier(SSLSocketFactoryTools.getInstance().getHostnameVerifier());
|
||||||
} else {
|
} else {
|
||||||
final SSLContext sslContext = SSLContext.getInstance("TLS");
|
final SSLContext sslContext = SSLContext.getInstance("TLS");
|
||||||
sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
|
sslContext.init(null, trustAllCerts, Common.getSecureRandom());
|
||||||
final javax.net.ssl.SSLSocketFactory sslSocketFactory = sslContext.getSocketFactory();
|
final SSLSocketFactory sslFactory = sslContext.getSocketFactory();
|
||||||
builder.sslSocketFactory(sslSocketFactory, trustManager);
|
builder.sslSocketFactory(sslFactory, trustManager);
|
||||||
builder.hostnameVerifier(new HostnameVerifier() {
|
builder.hostnameVerifier(new HostnameVerifier() {
|
||||||
@Override
|
@Override
|
||||||
public boolean verify(String arg0, SSLSession arg1) {
|
public boolean verify(String arg0, SSLSession arg1) {
|
||||||
|
@ -104,14 +107,18 @@ public class FLCommunication implements IFLCommunication {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return builder.build();
|
return builder.build();
|
||||||
} catch (NoSuchAlgorithmException | KeyManagementException e) {
|
} catch (NoSuchAlgorithmException | KeyManagementException ex) {
|
||||||
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + e.getMessage()));
|
LOGGER.severe(Common.addTag("[OkHttpClient] catch NoSuchAlgorithmException or KeyManagementException: " + ex.getMessage()));
|
||||||
throw new RuntimeException(e);
|
throw new IllegalArgumentException(ex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the singleton object of the class FLCommunication.
|
||||||
|
*
|
||||||
|
* @return the singleton object of the class FLCommunication.
|
||||||
|
*/
|
||||||
public static FLCommunication getInstance() {
|
public static FLCommunication getInstance() {
|
||||||
FLCommunication localRef = communication;
|
FLCommunication localRef = communication;
|
||||||
if (localRef == null) {
|
if (localRef == null) {
|
||||||
|
@ -138,6 +145,9 @@ public class FLCommunication implements IFLCommunication {
|
||||||
if (!response.isSuccessful()) {
|
if (!response.isSuccessful()) {
|
||||||
throw new IOException("Unexpected code " + response);
|
throw new IOException("Unexpected code " + response);
|
||||||
}
|
}
|
||||||
|
if (response.body() == null) {
|
||||||
|
throw new IOException("the returned response is null");
|
||||||
|
}
|
||||||
return response.body().bytes();
|
return response.body().bytes();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,11 +169,10 @@ public class FLCommunication implements IFLCommunication {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onFailure(Call call, IOException e) {
|
public void onFailure(Call call, IOException ioException) {
|
||||||
asyncCallBack.onFailure(e);
|
asyncCallBack.onFailure(ioException);
|
||||||
call.cancel();
|
call.cancel();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,19 +13,40 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define job result callback function.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class FLJobResultCallback implements IFLJobResultCallback {
|
public class FLJobResultCallback implements IFLJobResultCallback {
|
||||||
private static final Logger LOGGER = Logger.getLogger(FLJobResultCallback.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(FLJobResultCallback.class.toString());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called at the end of an iteration for Fl job
|
||||||
|
*
|
||||||
|
* @param modelName the name of model
|
||||||
|
* @param iterationSeq Iteration number
|
||||||
|
* @param resultCode Status Code
|
||||||
|
*/
|
||||||
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) {
|
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode) {
|
||||||
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " + iterationSeq + " resultCode: " + resultCode));
|
LOGGER.info(Common.addTag("[onFlJobIterationFinished] modelName: " + modelName + " iterationSeq: " +
|
||||||
|
iterationSeq + " resultCode: " + resultCode));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called on completion for Fl job
|
||||||
|
*
|
||||||
|
* @param modelName the name of model
|
||||||
|
* @param iterationCount total Iteration numbers
|
||||||
|
* @param resultCode Status Code
|
||||||
|
*/
|
||||||
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
|
public void onFlJobFinished(String modelName, int iterationCount, int resultCode) {
|
||||||
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " + iterationCount + " resultCode: " + resultCode));
|
LOGGER.info(Common.addTag("[onFlJobFinished] modelName: " + modelName + " iterationCount: " +
|
||||||
|
iterationCount + " resultCode: " + resultCode));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,11 +16,15 @@
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
import com.mindspore.flclient.cipher.BaseUtil;
|
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||||
|
|
||||||
import com.mindspore.flclient.model.AlInferBert;
|
import com.mindspore.flclient.model.AlInferBert;
|
||||||
import com.mindspore.flclient.model.AlTrainBert;
|
import com.mindspore.flclient.model.AlTrainBert;
|
||||||
import com.mindspore.flclient.model.SessionUtil;
|
import com.mindspore.flclient.model.SessionUtil;
|
||||||
import com.mindspore.flclient.model.TrainLenet;
|
import com.mindspore.flclient.model.TrainLenet;
|
||||||
|
|
||||||
import mindspore.schema.CipherPublicParams;
|
import mindspore.schema.CipherPublicParams;
|
||||||
import mindspore.schema.FLPlan;
|
import mindspore.schema.FLPlan;
|
||||||
import mindspore.schema.ResponseCode;
|
import mindspore.schema.ResponseCode;
|
||||||
|
@ -36,18 +40,21 @@ import java.util.Map;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
/**
|
||||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
* Defining the general process of federated learning tasks.
|
||||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class FLLiteClient {
|
public class FLLiteClient {
|
||||||
private static final Logger LOGGER = Logger.getLogger(FLLiteClient.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(FLLiteClient.class.toString());
|
||||||
private FLCommunication flCommunication;
|
private static int iteration = 0;
|
||||||
|
private static Map<String, float[]> mapBeforeTrain;
|
||||||
|
|
||||||
|
private double dpNormClipFactor = 1.0d;
|
||||||
|
private double dpNormClipAdapt = 0.05d;
|
||||||
|
private FLCommunication flCommunication;
|
||||||
private FLClientStatus status;
|
private FLClientStatus status;
|
||||||
private int retCode;
|
private int retCode;
|
||||||
|
|
||||||
private static int iteration = 0;
|
|
||||||
private int iterations = 1;
|
private int iterations = 1;
|
||||||
private int epochs = 1;
|
private int epochs = 1;
|
||||||
private int batchSize = 16;
|
private int batchSize = 16;
|
||||||
|
@ -55,22 +62,21 @@ public class FLLiteClient {
|
||||||
private byte[] prime;
|
private byte[] prime;
|
||||||
private int featureSize;
|
private int featureSize;
|
||||||
private int trainDataSize;
|
private int trainDataSize;
|
||||||
private double dpEps = 100;
|
private double dpEps = 100d;
|
||||||
private double dpDelta = 0.01;
|
private double dpDelta = 0.01d;
|
||||||
public double dpNormClipFactor = 1.0;
|
|
||||||
public double dpNormClipAdapt = 0.05;
|
|
||||||
|
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||||
private SecureProtocol secureProtocol = new SecureProtocol();
|
private SecureProtocol secureProtocol = new SecureProtocol();
|
||||||
private static Map<String, float[]> mapBeforeTrain;
|
|
||||||
private String nextRequestTime;
|
private String nextRequestTime;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defining a constructor of teh class FLLiteClient.
|
||||||
|
*/
|
||||||
public FLLiteClient() {
|
public FLLiteClient() {
|
||||||
flCommunication = FLCommunication.getInstance();
|
flCommunication = FLCommunication.getInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
public int setGlobalParameters(ResponseFLJob flJob) {
|
private int setGlobalParameters(ResponseFLJob flJob) {
|
||||||
FLPlan flPlan = flJob.flPlanConfig();
|
FLPlan flPlan = flJob.flPlanConfig();
|
||||||
if (flPlan == null) {
|
if (flPlan == null) {
|
||||||
LOGGER.severe(Common.addTag("[startFLJob] the FLPlan get from server is null"));
|
LOGGER.severe(Common.addTag("[startFLJob] the FLPlan get from server is null"));
|
||||||
|
@ -90,14 +96,22 @@ public class FLLiteClient {
|
||||||
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
|
LOGGER.info(Common.addTag("[startFLJob] set <batchSize> for TrainLenet: " + batchSize));
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
trainLenet.setBatchSize(batchSize);
|
trainLenet.setBatchSize(batchSize);
|
||||||
|
} else {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the ServerMod returned from server is not valid"));
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <epochs> from server: " + epochs));
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <batchSize> from server: " + batchSize));
|
||||||
CipherPublicParams cipherPublicParams = flPlan.cipher();
|
CipherPublicParams cipherPublicParams = flPlan.cipher();
|
||||||
|
if (cipherPublicParams == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the cipherPublicParams returned from server is null"));
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
String encryptLevel = cipherPublicParams.encryptType();
|
String encryptLevel = cipherPublicParams.encryptType();
|
||||||
if ("".equals(encryptLevel) || encryptLevel.isEmpty()) {
|
if (encryptLevel == null || encryptLevel.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server is null, set the encryptLevel to NOT_ENCRYPT "));
|
LOGGER.severe(Common.addTag("[startFLJob] GlobalParameters <encryptLevel> from server is null, set the " +
|
||||||
|
"encryptLevel to NOT_ENCRYPT "));
|
||||||
localFLParameter.setEncryptLevel(EncryptLevel.NOT_ENCRYPT.toString());
|
localFLParameter.setEncryptLevel(EncryptLevel.NOT_ENCRYPT.toString());
|
||||||
} else {
|
} else {
|
||||||
localFLParameter.setEncryptLevel(encryptLevel);
|
localFLParameter.setEncryptLevel(encryptLevel);
|
||||||
|
@ -113,10 +127,10 @@ public class FLLiteClient {
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server: " + minSecretNum));
|
||||||
if (minSecretNum <= 0) {
|
if (minSecretNum <= 0) {
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid: <=0"));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <minSecretNum> from server is not valid:" +
|
||||||
|
" <=0"));
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[Encrypt] the prime from server: " + BaseUtil.byte2HexString(prime)));
|
|
||||||
break;
|
break;
|
||||||
case DP_ENCRYPT:
|
case DP_ENCRYPT:
|
||||||
dpEps = cipherPublicParams.dpEps();
|
dpEps = cipherPublicParams.dpEps();
|
||||||
|
@ -124,53 +138,97 @@ public class FLLiteClient {
|
||||||
dpNormClipFactor = cipherPublicParams.dpNormClip();
|
dpNormClipFactor = cipherPublicParams.dpNormClip();
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpEps> from server: " + dpEps));
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpDelta> from server: " + dpDelta));
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " + dpNormClipFactor));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <dpNormClipFactor> from server: " +
|
||||||
|
dpNormClipFactor));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOGGER.info(Common.addTag("[startFLJob] NotEncrypt, do not set parameter for Encrypt"));
|
LOGGER.info(Common.addTag("[startFLJob] NOT_ENCRYPT, do not set parameter for Encrypt"));
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain retCode returned by server.
|
||||||
|
*
|
||||||
|
* @return the retCode returned by server.
|
||||||
|
*/
|
||||||
public int getRetCode() {
|
public int getRetCode() {
|
||||||
return retCode;
|
return retCode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain current iteration returned by server.
|
||||||
|
*
|
||||||
|
* @return the current iteration returned by server.
|
||||||
|
*/
|
||||||
public int getIteration() {
|
public int getIteration() {
|
||||||
return iteration;
|
return iteration;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain total iterations for the task returned by server.
|
||||||
|
*
|
||||||
|
* @return the total iterations for the task returned by server.
|
||||||
|
*/
|
||||||
public int getIterations() {
|
public int getIterations() {
|
||||||
return iterations;
|
return iterations;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getEpochs() {
|
/**
|
||||||
return epochs;
|
* Obtain the returned timestamp for next request from server.
|
||||||
}
|
*
|
||||||
|
* @return the timestamp for next request.
|
||||||
public int getBatchSize() {
|
*/
|
||||||
return batchSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getNextRequestTime() {
|
public String getNextRequestTime() {
|
||||||
return nextRequestTime;
|
return nextRequestTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setNextRequestTime(String nextRequestTime) {
|
/**
|
||||||
this.nextRequestTime = nextRequestTime;
|
* set the size of train date set.
|
||||||
}
|
*
|
||||||
|
* @param trainDataSize the size of train date set.
|
||||||
|
*/
|
||||||
public void setTrainDataSize(int trainDataSize) {
|
public void setTrainDataSize(int trainDataSize) {
|
||||||
this.trainDataSize = trainDataSize;
|
this.trainDataSize = trainDataSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus checkStatus() {
|
/**
|
||||||
return this.status;
|
* Obtain the dpNormClipFactor.
|
||||||
|
*
|
||||||
|
* @return the dpNormClipFactor.
|
||||||
|
*/
|
||||||
|
public double getDpNormClipFactor() {
|
||||||
|
return dpNormClipFactor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain the dpNormClipAdapt.
|
||||||
|
*
|
||||||
|
* @return the dpNormClipAdapt.
|
||||||
|
*/
|
||||||
|
public double getDpNormClipAdapt() {
|
||||||
|
return dpNormClipAdapt;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the dpNormClipAdapt.
|
||||||
|
*
|
||||||
|
* @param dpNormClipAdapt the dpNormClipAdapt.
|
||||||
|
*/
|
||||||
|
public void setDpNormClipAdapt(double dpNormClipAdapt) {
|
||||||
|
this.dpNormClipAdapt = dpNormClipAdapt;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send serialized request message of startFLJob to server.
|
||||||
|
*
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus startFLJob() {
|
public FLClientStatus startFLJob() {
|
||||||
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify server===================================="));
|
LOGGER.info(Common.addTag("[startFLJob] ====================================Verify " +
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
"server===================================="));
|
||||||
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||||
|
flParameter.getDomainName());
|
||||||
StartFLJob startFLJob = StartFLJob.getInstance();
|
StartFLJob startFLJob = StartFLJob.getInstance();
|
||||||
Date date = new Date();
|
Date date = new Date();
|
||||||
long time = date.getTime();
|
long time = date.getTime();
|
||||||
|
@ -179,8 +237,9 @@ public class FLLiteClient {
|
||||||
long start = Common.startTime("single startFLJob");
|
long start = Common.startTime("single startFLJob");
|
||||||
LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length));
|
LOGGER.info(Common.addTag("[startFLJob] the request message length: " + msg.length));
|
||||||
byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg);
|
byte[] message = flCommunication.syncRequest(url + "/startFLJob", msg);
|
||||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(message)) {
|
||||||
LOGGER.info(Common.addTag("[startFLJob] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[startFLJob] the server is not ready now, need wait some time and request " +
|
||||||
|
"again"));
|
||||||
status = FLClientStatus.RESTART;
|
status = FLClientStatus.RESTART;
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
nextRequestTime = "";
|
nextRequestTime = "";
|
||||||
|
@ -193,14 +252,15 @@ public class FLLiteClient {
|
||||||
status = judgeStartFLJob(startFLJob, responseDataBuf);
|
status = judgeStartFLJob(startFLJob, responseDataBuf);
|
||||||
retCode = responseDataBuf.retcode();
|
retCode = responseDataBuf.retcode();
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " + e.getMessage()));
|
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in StartFLJob: catch IOException: " +
|
||||||
|
e.getMessage()));
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
|
private FLClientStatus judgeStartFLJob(StartFLJob startFLJob, ResponseFLJob responseDataBuf) {
|
||||||
iteration = responseDataBuf.iteration();
|
iteration = responseDataBuf.iteration();
|
||||||
FLClientStatus response = startFLJob.doResponse(responseDataBuf);
|
FLClientStatus response = startFLJob.doResponse(responseDataBuf);
|
||||||
status = response;
|
status = response;
|
||||||
|
@ -218,6 +278,10 @@ public class FLLiteClient {
|
||||||
break;
|
break;
|
||||||
case RESTART:
|
case RESTART:
|
||||||
FLPlan flPlan = responseDataBuf.flPlanConfig();
|
FLPlan flPlan = responseDataBuf.flPlanConfig();
|
||||||
|
if (flPlan == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the flPlan returned from server is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
iterations = flPlan.iterations();
|
iterations = flPlan.iterations();
|
||||||
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
|
LOGGER.info(Common.addTag("[startFLJob] GlobalParameters <iterations> from server: " + iterations));
|
||||||
nextRequestTime = responseDataBuf.nextReqTime();
|
nextRequestTime = responseDataBuf.nextReqTime();
|
||||||
|
@ -226,14 +290,21 @@ public class FLLiteClient {
|
||||||
LOGGER.severe(Common.addTag("[startFLJob] startFLJob failed"));
|
LOGGER.severe(Common.addTag("[startFLJob] startFLJob failed"));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range <SUCCESS, WAIT, FAILED, Restart>"));
|
LOGGER.severe(Common.addTag("[startFLJob] failed: the response of startFLJob is out of range " +
|
||||||
|
"<SUCCESS, WAIT, FAILED, Restart>"));
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define the training process.
|
||||||
|
*
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus localTrain() {
|
public FLClientStatus localTrain() {
|
||||||
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration + "===================================="));
|
LOGGER.info(Common.addTag("[train] ====================================global train epoch " + iteration +
|
||||||
|
"===================================="));
|
||||||
status = FLClientStatus.SUCCESS;
|
status = FLClientStatus.SUCCESS;
|
||||||
retCode = ResponseCode.SUCCEED;
|
retCode = ResponseCode.SUCCEED;
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
|
@ -254,12 +325,22 @@ public class FLLiteClient {
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
LOGGER.severe(Common.addTag("[train] the flName is not valid"));
|
||||||
|
status = FLClientStatus.FAILED;
|
||||||
|
retCode = ResponseCode.RequestError;
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send serialized request message of updateModel to server.
|
||||||
|
*
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus updateModel() {
|
public FLClientStatus updateModel() {
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||||
|
flParameter.getDomainName());
|
||||||
UpdateModel updateModelBuf = UpdateModel.getInstance();
|
UpdateModel updateModelBuf = UpdateModel.getInstance();
|
||||||
byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize);
|
byte[] updateModelBuffer = updateModelBuf.getRequestUpdateFLJob(iteration, secureProtocol, trainDataSize);
|
||||||
if (updateModelBuf.getStatus() == FLClientStatus.FAILED) {
|
if (updateModelBuf.getStatus() == FLClientStatus.FAILED) {
|
||||||
|
@ -270,8 +351,9 @@ public class FLLiteClient {
|
||||||
long start = Common.startTime("single updateModel");
|
long start = Common.startTime("single updateModel");
|
||||||
LOGGER.info(Common.addTag("[updateModel] the request message length: " + updateModelBuffer.length));
|
LOGGER.info(Common.addTag("[updateModel] the request message length: " + updateModelBuffer.length));
|
||||||
byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer);
|
byte[] message = flCommunication.syncRequest(url + "/updateModel", updateModelBuffer);
|
||||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(message)) {
|
||||||
LOGGER.info(Common.addTag("[updateModel] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[updateModel] the server is not ready now, need wait some time and request" +
|
||||||
|
" again"));
|
||||||
status = FLClientStatus.RESTART;
|
status = FLClientStatus.RESTART;
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
nextRequestTime = "";
|
nextRequestTime = "";
|
||||||
|
@ -288,23 +370,31 @@ public class FLLiteClient {
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[updateModel] get response from server ok!"));
|
LOGGER.info(Common.addTag("[updateModel] get response from server ok!"));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " + e.getMessage()));
|
LOGGER.severe(Common.addTag("[updateModel] unsolved error code in updateModel: catch IOException: " +
|
||||||
|
e.getMessage()));
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Send serialized request message of getModel to server.
|
||||||
|
*
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus getModel() {
|
public FLClientStatus getModel() {
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||||
|
flParameter.getDomainName());
|
||||||
GetModel getModelBuf = GetModel.getInstance();
|
GetModel getModelBuf = GetModel.getInstance();
|
||||||
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration);
|
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), iteration);
|
||||||
try {
|
try {
|
||||||
long start = Common.startTime("single getModel");
|
long start = Common.startTime("single getModel");
|
||||||
LOGGER.info(Common.addTag("[getModel] the request message length: " + buffer.length));
|
LOGGER.info(Common.addTag("[getModel] the request message length: " + buffer.length));
|
||||||
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
||||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(message)) {
|
||||||
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
|
||||||
|
"again"));
|
||||||
status = FLClientStatus.WAIT;
|
status = FLClientStatus.WAIT;
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
@ -327,6 +417,12 @@ public class FLLiteClient {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain the weight of the model before training.
|
||||||
|
*
|
||||||
|
* @param map a map to store the weight of the model.
|
||||||
|
* @return the weight.
|
||||||
|
*/
|
||||||
public static synchronized Map<String, float[]> getOldMapCopy(Map<String, float[]> map) {
|
public static synchronized Map<String, float[]> getOldMapCopy(Map<String, float[]> map) {
|
||||||
if (mapBeforeTrain == null) {
|
if (mapBeforeTrain == null) {
|
||||||
Map<String, float[]> copyMap = new TreeMap<>();
|
Map<String, float[]> copyMap = new TreeMap<>();
|
||||||
|
@ -334,7 +430,8 @@ public class FLLiteClient {
|
||||||
float[] data = map.get(key);
|
float[] data = map.get(key);
|
||||||
int dataLen = data.length;
|
int dataLen = data.length;
|
||||||
float[] weights = new float[dataLen];
|
float[] weights = new float[dataLen];
|
||||||
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
|
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) &&
|
||||||
|
(key.indexOf("learning") < 0)) {
|
||||||
for (int j = 0; j < dataLen; j++) {
|
for (int j = 0; j < dataLen; j++) {
|
||||||
float weight = data[j];
|
float weight = data[j];
|
||||||
weights[j] = weight;
|
weights[j] = weight;
|
||||||
|
@ -348,7 +445,8 @@ public class FLLiteClient {
|
||||||
float[] data = map.get(key);
|
float[] data = map.get(key);
|
||||||
float[] copyData = mapBeforeTrain.get(key);
|
float[] copyData = mapBeforeTrain.get(key);
|
||||||
int dataLen = data.length;
|
int dataLen = data.length;
|
||||||
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) && (key.indexOf("learning") < 0)) {
|
if ((key.indexOf("Default") < 0) && (key.indexOf("nhwc") < 0) && (key.indexOf("moment") < 0) &&
|
||||||
|
(key.indexOf("learning") < 0)) {
|
||||||
for (int j = 0; j < dataLen; j++) {
|
for (int j = 0; j < dataLen; j++) {
|
||||||
copyData[j] = data[j];
|
copyData[j] = data[j];
|
||||||
}
|
}
|
||||||
|
@ -358,18 +456,25 @@ public class FLLiteClient {
|
||||||
return mapBeforeTrain;
|
return mapBeforeTrain;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain pairwise mask and individual mask.
|
||||||
|
*
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus getFeatureMask() {
|
public FLClientStatus getFeatureMask() {
|
||||||
FLClientStatus curStatus;
|
FLClientStatus curStatus;
|
||||||
switch (localFLParameter.getEncryptLevel()) {
|
switch (localFLParameter.getEncryptLevel()) {
|
||||||
case PW_ENCRYPT:
|
case PW_ENCRYPT:
|
||||||
LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" + localFLParameter.getEncryptLevel().toString() + ">"));
|
LOGGER.info(Common.addTag("[Encrypt] creating feature mask of <" +
|
||||||
|
localFLParameter.getEncryptLevel().toString() + ">"));
|
||||||
secureProtocol.setPWParameter(iteration, minSecretNum, prime, featureSize);
|
secureProtocol.setPWParameter(iteration, minSecretNum, prime, featureSize);
|
||||||
curStatus = secureProtocol.pwCreateMask();
|
curStatus = secureProtocol.pwCreateMask();
|
||||||
if (curStatus == FLClientStatus.RESTART) {
|
if (curStatus == FLClientStatus.RESTART) {
|
||||||
nextRequestTime = secureProtocol.getNextRequestTime();
|
nextRequestTime = secureProtocol.getNextRequestTime();
|
||||||
}
|
}
|
||||||
retCode = secureProtocol.getRetCode();
|
retCode = secureProtocol.getRetCode();
|
||||||
LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" + localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
|
LOGGER.info(Common.addTag("[Encrypt] the response of create mask for <" +
|
||||||
|
localFLParameter.getEncryptLevel().toString() + "> : " + curStatus));
|
||||||
return curStatus;
|
return curStatus;
|
||||||
case DP_ENCRYPT:
|
case DP_ENCRYPT:
|
||||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||||
|
@ -388,7 +493,7 @@ public class FLLiteClient {
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[Encrypt] set parameters for DPEncrypt!"));
|
LOGGER.info(Common.addTag("[Encrypt] set parameters for DP_ENCRYPT!"));
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
case NOT_ENCRYPT:
|
case NOT_ENCRYPT:
|
||||||
retCode = ResponseCode.SUCCEED;
|
retCode = ResponseCode.SUCCEED;
|
||||||
|
@ -401,6 +506,11 @@ public class FLLiteClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reconstruct the secrets used for unmasking model weights.
|
||||||
|
*
|
||||||
|
* @return current status code in client.
|
||||||
|
*/
|
||||||
public FLClientStatus unMasking() {
|
public FLClientStatus unMasking() {
|
||||||
FLClientStatus curStatus;
|
FLClientStatus curStatus;
|
||||||
switch (localFLParameter.getEncryptLevel()) {
|
switch (localFLParameter.getEncryptLevel()) {
|
||||||
|
@ -413,7 +523,7 @@ public class FLLiteClient {
|
||||||
}
|
}
|
||||||
return curStatus;
|
return curStatus;
|
||||||
case DP_ENCRYPT:
|
case DP_ENCRYPT:
|
||||||
LOGGER.info(Common.addTag("[Encrypt] DPEncrypt do not need unmasking"));
|
LOGGER.info(Common.addTag("[Encrypt] DP_ENCRYPT do not need unmasking"));
|
||||||
retCode = ResponseCode.SUCCEED;
|
retCode = ResponseCode.SUCCEED;
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
case NOT_ENCRYPT:
|
case NOT_ENCRYPT:
|
||||||
|
@ -427,18 +537,26 @@ public class FLLiteClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluate model after getting model from server.
|
||||||
|
*
|
||||||
|
* @return the status code in client.
|
||||||
|
*/
|
||||||
public FLClientStatus evaluateModel() {
|
public FLClientStatus evaluateModel() {
|
||||||
status = FLClientStatus.SUCCESS;
|
status = FLClientStatus.SUCCESS;
|
||||||
retCode = ResponseCode.SUCCEED;
|
retCode = ResponseCode.SUCCEED;
|
||||||
LOGGER.info(Common.addTag("===================================evaluate model after getting model from server==================================="));
|
LOGGER.info(Common.addTag("===================================evaluate model after getting model from " +
|
||||||
|
"server==================================="));
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
float acc = 0;
|
float acc = 0;
|
||||||
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
if (localFLParameter.getServerMod().equals(ServerMod.HYBRID_TRAINING.toString())) {
|
||||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||||
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile(), true);
|
int dataSize = alInferBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
|
||||||
|
flParameter.getIdsFile(), true);
|
||||||
if (dataSize <= 0) {
|
if (dataSize <= 0) {
|
||||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the return dataSize<=0"));
|
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alInferBert.initDataSet>: the " +
|
||||||
|
"return dataSize<=0"));
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
return status;
|
return status;
|
||||||
|
@ -447,47 +565,66 @@ public class FLLiteClient {
|
||||||
} else {
|
} else {
|
||||||
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
LOGGER.info(Common.addTag("[evaluate] evaluateModel by " + localFLParameter.getServerMod()));
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile());
|
int dataSize = alTrainBert.initDataSet(flParameter.getTestDataset(), flParameter.getVocabFile(),
|
||||||
|
flParameter.getIdsFile());
|
||||||
if (dataSize <= 0) {
|
if (dataSize <= 0) {
|
||||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the return dataSize<=0"));
|
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <alTrainBert.initDataSet>: the " +
|
||||||
|
"return dataSize<=0"));
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
acc = alTrainBert.evalModel();
|
acc = alTrainBert.evalModel();
|
||||||
}
|
}
|
||||||
if (acc == Float.NaN) {
|
if (Float.isNaN(acc)) {
|
||||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <evalModel>: the return acc is NAN"));
|
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <evalModel>: the return acc is NAN"));
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
|
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
|
||||||
|
flParameter.getTestDataset() + " vocabFile: " + flParameter.getVocabFile() +
|
||||||
|
" idsFile: " + flParameter.getIdsFile()));
|
||||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||||
} else if (flParameter.getFlName().equals(LENET)) {
|
} else if (flParameter.getFlName().equals(LENET)) {
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0], flParameter.getTestDataset().split(",")[1]);
|
if (flParameter.getTestDataset().split(",").length < 2) {
|
||||||
|
LOGGER.severe(Common.addTag("[evaluate] the set testDataPath for lenet is not valid, should be the " +
|
||||||
|
"format of <data.bin,label.bin> "));
|
||||||
|
status = FLClientStatus.FAILED;
|
||||||
|
retCode = ResponseCode.RequestError;
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
int dataSize = trainLenet.initDataSet(flParameter.getTestDataset().split(",")[0],
|
||||||
|
flParameter.getTestDataset().split(",")[1]);
|
||||||
if (dataSize <= 0) {
|
if (dataSize <= 0) {
|
||||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return dataSize<=0"));
|
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.initDataSet>: the return " +
|
||||||
|
"dataSize<=0"));
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
float acc = trainLenet.evalModel();
|
float acc = trainLenet.evalModel();
|
||||||
if (acc == Float.NaN) {
|
if (Float.isNaN(acc)) {
|
||||||
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc is NAN"));
|
LOGGER.severe(Common.addTag("[evaluate] unsolved error code in <trainLenet.evalModel>: the return acc" +
|
||||||
|
" is NAN"));
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " + flParameter.getTestDataset().split(",")[0] + " labelPath: " + flParameter.getTestDataset().split(",")[1]));
|
LOGGER.info(Common.addTag("[evaluate] modelPath: " + flParameter.getInferModelPath() + " dataPath: " +
|
||||||
|
flParameter.getTestDataset().split(",")[0] + " labelPath: " +
|
||||||
|
flParameter.getTestDataset().split(",")[1]));
|
||||||
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
LOGGER.info(Common.addTag("[evaluate] evaluate acc: " + acc));
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param dataPath, train or test dataset and label set
|
* Set date path.
|
||||||
|
*
|
||||||
|
* @param dataPath, train or test dataset and label set.
|
||||||
|
* @return date size.
|
||||||
*/
|
*/
|
||||||
public int setInput(String dataPath) {
|
public int setInput(String dataPath) {
|
||||||
retCode = ResponseCode.SUCCEED;
|
retCode = ResponseCode.SUCCEED;
|
||||||
|
@ -496,15 +633,18 @@ public class FLLiteClient {
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
dataSize = alTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
|
dataSize = alTrainBert.initDataSet(dataPath, flParameter.getVocabFile(), flParameter.getIdsFile());
|
||||||
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
|
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath + " dataSize: " + +dataSize + " " +
|
||||||
|
"vocabFile: " + flParameter.getVocabFile() + " idsFile: " + flParameter.getIdsFile()));
|
||||||
} else if (flParameter.getFlName().equals(LENET)) {
|
} else if (flParameter.getFlName().equals(LENET)) {
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
if (dataPath.split(",").length < 2) {
|
if (dataPath.split(",").length < 2) {
|
||||||
LOGGER.info(Common.addTag("[set input] the set dataPath for lenet is not valid, should be the format of <data.bin,label.bin>"));
|
LOGGER.severe(Common.addTag("[set input] the set dataPath for lenet is not valid, should be the " +
|
||||||
|
"format of <data.bin,label.bin> "));
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]);
|
dataSize = trainLenet.initDataSet(dataPath.split(",")[0], dataPath.split(",")[1]);
|
||||||
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " + +dataSize + " labelPath: " + dataPath.split(",")[1]));
|
LOGGER.info(Common.addTag("[set input] " + "dataPath: " + dataPath.split(",")[0] + " dataSize: " +
|
||||||
|
dataSize + " labelPath: " + dataPath.split(",")[1]));
|
||||||
}
|
}
|
||||||
if (dataSize <= 0) {
|
if (dataSize <= 0) {
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
|
@ -513,36 +653,48 @@ public class FLLiteClient {
|
||||||
return dataSize;
|
return dataSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialization session.
|
||||||
|
*
|
||||||
|
* @return the status code in client.
|
||||||
|
*/
|
||||||
public FLClientStatus initSession() {
|
public FLClientStatus initSession() {
|
||||||
int tag = 0;
|
int tag = 0;
|
||||||
retCode = ResponseCode.SUCCEED;
|
retCode = ResponseCode.SUCCEED;
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
|
||||||
|
"Train Session============="));
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
|
||||||
|
"is -1"));
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
|
LOGGER.info(Common.addTag("==========Loading inference model, " + flParameter.getInferModelPath() + " " +
|
||||||
|
"Create inference Session============="));
|
||||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||||
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||||
} else if (flParameter.getFlName().equals(LENET)) {
|
} else if (flParameter.getFlName().equals(LENET)) {
|
||||||
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
LOGGER.info(Common.addTag("==========Loading train model, " + flParameter.getTrainModelPath() + " Create " +
|
||||||
|
"Train Session============="));
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||||
}
|
}
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is " +
|
||||||
|
"-1"));
|
||||||
retCode = ResponseCode.RequestError;
|
retCode = ResponseCode.RequestError;
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
/**
|
||||||
protected void finalize() {
|
* Free session.
|
||||||
|
*/
|
||||||
|
protected void freeSession() {
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
LOGGER.info(Common.addTag("===========free train session============="));
|
LOGGER.info(Common.addTag("===========free train session============="));
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
|
@ -558,5 +710,4 @@ public class FLLiteClient {
|
||||||
SessionUtil.free(trainLenet.getTrainSession());
|
SessionUtil.free(trainLenet.getTrainSession());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,22 +13,36 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
package com.mindspore.flclient;
|
|
||||||
|
|
||||||
import java.util.logging.Logger;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defines global parameters used during federated learning and these parameters are provided for users to set.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class FLParameter {
|
public class FLParameter {
|
||||||
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(FLParameter.class.toString());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The timeout interval for communication on the device.
|
||||||
|
*/
|
||||||
public static final int TIME_OUT = 100;
|
public static final int TIME_OUT = 100;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The waiting time of repeated requests.
|
||||||
|
*/
|
||||||
public static final int SLEEP_TIME = 1000;
|
public static final int SLEEP_TIME = 1000;
|
||||||
|
private static volatile FLParameter flParameter;
|
||||||
|
|
||||||
private String hostName;
|
private String domainName;
|
||||||
private String certPath;
|
private String certPath;
|
||||||
private boolean useHttps = false;
|
|
||||||
|
|
||||||
private String trainDataset;
|
private String trainDataset;
|
||||||
private String vocabFile = "null";
|
private String vocabFile = "null";
|
||||||
private String idsFile = "null";
|
private String idsFile = "null";
|
||||||
|
@ -37,22 +51,21 @@ public class FLParameter {
|
||||||
private String trainModelPath;
|
private String trainModelPath;
|
||||||
private String inferModelPath;
|
private String inferModelPath;
|
||||||
private String clientID;
|
private String clientID;
|
||||||
private String ip;
|
|
||||||
private int port;
|
|
||||||
private boolean useSSL = false;
|
private boolean useSSL = false;
|
||||||
private int timeOut;
|
private int timeOut;
|
||||||
private int sleepTime;
|
private int sleepTime;
|
||||||
private boolean useElb = false;
|
private boolean ifUseElb = false;
|
||||||
private int serverNum = 1;
|
private int serverNum = 1;
|
||||||
|
|
||||||
private boolean timer = true;
|
private FLParameter() {
|
||||||
private int timeWindow = 6000;
|
clientID = UUID.randomUUID().toString();
|
||||||
private int reRequestNum = timeWindow / SLEEP_TIME + 1;
|
}
|
||||||
|
|
||||||
private static volatile FLParameter flParameter;
|
|
||||||
|
|
||||||
private FLParameter() {}
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the singleton object of the class FLParameter.
|
||||||
|
*
|
||||||
|
* @return the singleton object of the class FLParameter.
|
||||||
|
*/
|
||||||
public static FLParameter getInstance() {
|
public static FLParameter getInstance() {
|
||||||
FLParameter localRef = flParameter;
|
FLParameter localRef = flParameter;
|
||||||
if (localRef == null) {
|
if (localRef == null) {
|
||||||
|
@ -66,95 +79,100 @@ public class FLParameter {
|
||||||
return localRef;
|
return localRef;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getHostName() {
|
public String getDomainName() {
|
||||||
if ("".equals(hostName) || hostName.isEmpty()) {
|
if (domainName == null || domainName.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <hostName> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is null or empty, please set it " +
|
||||||
throw new RuntimeException();
|
"before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return hostName;
|
return domainName;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setHostName(String hostName) {
|
public void setDomainName(String domainName) {
|
||||||
this.hostName = hostName;
|
if (domainName == null || domainName.isEmpty() || (!("https".equals(domainName.split(":")[0]) || "http".equals(domainName.split(":")[0])))) {
|
||||||
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <domainName> is not valid, it should be like " +
|
||||||
|
"as https://...... or http://......, please check it before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
this.domainName = domainName;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getCertPath() {
|
public String getCertPath() {
|
||||||
if ("".equals(certPath) || certPath.isEmpty()) {
|
if (certPath == null || certPath.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is null or empty, please set it " +
|
||||||
throw new RuntimeException();
|
"before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return certPath;
|
return certPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setCertPath(String certPath) {
|
public void setCertPath(String certPath) {
|
||||||
certPath = Common.getRealPath(certPath);
|
String realCertPath = Common.getRealPath(certPath);
|
||||||
if (Common.checkPath(certPath)) {
|
if (Common.checkPath(realCertPath)) {
|
||||||
this.certPath = certPath;
|
this.certPath = realCertPath;
|
||||||
} else {
|
} else {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it before set"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <certPath> is not exist, please check it " +
|
||||||
throw new RuntimeException();
|
"before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isUseHttps() {
|
|
||||||
return useHttps;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setUseHttps(boolean useHttps) {
|
|
||||||
this.useHttps = useHttps;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getTrainDataset() {
|
public String getTrainDataset() {
|
||||||
if ("".equals(trainDataset) || trainDataset.isEmpty()) {
|
if (trainDataset == null || trainDataset.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is null or empty, please set " +
|
||||||
throw new RuntimeException();
|
"it before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return trainDataset;
|
return trainDataset;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setTrainDataset(String trainDataset) {
|
public void setTrainDataset(String trainDataset) {
|
||||||
trainDataset = Common.getRealPath(trainDataset);
|
String realTrainDataset = Common.getRealPath(trainDataset);
|
||||||
if (Common.checkPath(trainDataset)) {
|
if (Common.checkPath(realTrainDataset)) {
|
||||||
this.trainDataset = trainDataset;
|
this.trainDataset = realTrainDataset;
|
||||||
} else {
|
} else {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it before set"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainDataset> is not exist, please check it " +
|
||||||
throw new RuntimeException();
|
"before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getVocabFile() {
|
public String getVocabFile() {
|
||||||
if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
|
if ("null".equals(vocabFile) && ALBERT.equals(flName)) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is null, please set it before " +
|
||||||
throw new RuntimeException();
|
"use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return vocabFile;
|
return vocabFile;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setVocabFile(String vocabFile) {
|
public void setVocabFile(String vocabFile) {
|
||||||
vocabFile = Common.getRealPath(vocabFile);
|
String realVocabFile = Common.getRealPath(vocabFile);
|
||||||
if (Common.checkPath(vocabFile)) {
|
if (Common.checkPath(realVocabFile)) {
|
||||||
this.vocabFile = vocabFile;
|
this.vocabFile = realVocabFile;
|
||||||
} else {
|
} else {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it before set"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <vocabFile> is not exist, please check it " +
|
||||||
throw new RuntimeException();
|
"before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getIdsFile() {
|
public String getIdsFile() {
|
||||||
if ("null".equals(idsFile) && ALBERT.equals(flName)) {
|
if ("null".equals(idsFile) && ALBERT.equals(flName)) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is null, please set it before use"));
|
||||||
throw new RuntimeException();
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return idsFile;
|
return idsFile;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setIdsFile(String idsFile) {
|
public void setIdsFile(String idsFile) {
|
||||||
idsFile = Common.getRealPath(idsFile);
|
String realIdsFile = Common.getRealPath(idsFile);
|
||||||
if (Common.checkPath(idsFile)) {
|
if (Common.checkPath(realIdsFile)) {
|
||||||
this.idsFile = idsFile;
|
this.idsFile = realIdsFile;
|
||||||
} else {
|
} else {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it before set"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <idsFile> is not exist, please check it " +
|
||||||
throw new RuntimeException();
|
"before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,19 +181,21 @@ public class FLParameter {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setTestDataset(String testDataset) {
|
public void setTestDataset(String testDataset) {
|
||||||
testDataset = Common.getRealPath(testDataset);
|
String realTestDataset = Common.getRealPath(testDataset);
|
||||||
if (Common.checkPath(testDataset)) {
|
if (Common.checkPath(realTestDataset)) {
|
||||||
this.testDataset = testDataset;
|
this.testDataset = realTestDataset;
|
||||||
} else {
|
} else {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it before set"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <testDataset> is not exist, please check it " +
|
||||||
throw new RuntimeException();
|
"before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getFlName() {
|
public String getFlName() {
|
||||||
if ("".equals(flName) || flName.isEmpty()) {
|
if (flName == null || flName.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is null or empty, please set it " +
|
||||||
throw new RuntimeException();
|
"before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return flName;
|
return flName;
|
||||||
}
|
}
|
||||||
|
@ -184,61 +204,50 @@ public class FLParameter {
|
||||||
if (Common.checkFLName(flName)) {
|
if (Common.checkFLName(flName)) {
|
||||||
this.flName = flName;
|
this.flName = flName;
|
||||||
} else {
|
} else {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in flNameTrustList, please check it before set"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <flName> is not in FL_NAME_TRUST_LIST: " +
|
||||||
throw new RuntimeException();
|
Arrays.toString(Common.FL_NAME_TRUST_LIST.toArray(new String[0])) + ", please check it before " +
|
||||||
|
"set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getTrainModelPath() {
|
public String getTrainModelPath() {
|
||||||
if ("".equals(trainModelPath) || trainModelPath.isEmpty()) {
|
if (trainModelPath == null || trainModelPath.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is null or empty, please set" +
|
||||||
throw new RuntimeException();
|
" it before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return trainModelPath;
|
return trainModelPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setTrainModelPath(String trainModelPath) {
|
public void setTrainModelPath(String trainModelPath) {
|
||||||
trainModelPath = Common.getRealPath(trainModelPath);
|
String realTrainModelPath = Common.getRealPath(trainModelPath);
|
||||||
if (Common.checkPath(trainModelPath)) {
|
if (Common.checkPath(realTrainModelPath)) {
|
||||||
this.trainModelPath = trainModelPath;
|
this.trainModelPath = realTrainModelPath;
|
||||||
} else {
|
} else {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check it before set"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <trainModelPath> is not exist, please check " +
|
||||||
throw new RuntimeException();
|
"it before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getInferModelPath() {
|
public String getInferModelPath() {
|
||||||
if ("".equals(inferModelPath) || inferModelPath.isEmpty()) {
|
if (inferModelPath == null || inferModelPath.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is null or empty, please set" +
|
||||||
throw new RuntimeException();
|
" it before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return inferModelPath;
|
return inferModelPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setInferModelPath(String inferModelPath) {
|
public void setInferModelPath(String inferModelPath) {
|
||||||
inferModelPath = Common.getRealPath(inferModelPath);
|
String realInferModelPath = Common.getRealPath(inferModelPath);
|
||||||
if (Common.checkPath(inferModelPath)) {
|
if (Common.checkPath(realInferModelPath)) {
|
||||||
this.inferModelPath = inferModelPath;
|
this.inferModelPath = realInferModelPath;
|
||||||
} else {
|
} else {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check it before set"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <inferModelPath> is not exist, please check " +
|
||||||
throw new RuntimeException();
|
"it before set"));
|
||||||
}
|
throw new IllegalArgumentException();
|
||||||
}
|
|
||||||
|
|
||||||
public String getIp() {
|
|
||||||
if ("".equals(ip) || ip.isEmpty()) {
|
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is null, please set it before use"));
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
|
||||||
return ip;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setIp(String ip) {
|
|
||||||
if (Common.checkIP(ip)) {
|
|
||||||
this.ip = ip;
|
|
||||||
} else {
|
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <ip> is not valid, please check it before set"));
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,23 +259,6 @@ public class FLParameter {
|
||||||
this.useSSL = useSSL;
|
this.useSSL = useSSL;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getPort() {
|
|
||||||
if (port == 0) {
|
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is null, please set it before use"));
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
|
||||||
return port;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setPort(int port) {
|
|
||||||
if (Common.checkPort(port)) {
|
|
||||||
this.port = port;
|
|
||||||
} else {
|
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <port> is not valid, please check it before set"));
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getTimeOut() {
|
public int getTimeOut() {
|
||||||
return timeOut;
|
return timeOut;
|
||||||
}
|
}
|
||||||
|
@ -284,17 +276,18 @@ public class FLParameter {
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isUseElb() {
|
public boolean isUseElb() {
|
||||||
return useElb;
|
return ifUseElb;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setUseElb(boolean useElb) {
|
public void setUseElb(boolean ifUseElb) {
|
||||||
this.useElb = useElb;
|
this.ifUseElb = ifUseElb;
|
||||||
}
|
}
|
||||||
|
|
||||||
public int getServerNum() {
|
public int getServerNum() {
|
||||||
if (serverNum <= 0) {
|
if (serverNum <= 0) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> is <= 0, it should be > 0, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <serverNum> <= 0, it should be > 0, please " +
|
||||||
throw new RuntimeException();
|
"set it before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return serverNum;
|
return serverNum;
|
||||||
}
|
}
|
||||||
|
@ -303,40 +296,11 @@ public class FLParameter {
|
||||||
this.serverNum = serverNum;
|
this.serverNum = serverNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean isTimer() {
|
|
||||||
return timer;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setTimer(boolean timer) {
|
|
||||||
this.timer = timer;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getTimeWindow() {
|
|
||||||
return timeWindow;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setTimeWindow(int timeWindow) {
|
|
||||||
this.timeWindow = timeWindow;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getReRequestNum() {
|
|
||||||
return reRequestNum;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setReRequestNum(int reRequestNum) {
|
|
||||||
this.reRequestNum = reRequestNum;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getClientID() {
|
public String getClientID() {
|
||||||
if ("".equals(clientID) || clientID.isEmpty()) {
|
if (clientID == null || clientID.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[flParameter] the parameter of <clientID> is null or empty, please check"));
|
||||||
throw new RuntimeException();
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return clientID;
|
return clientID;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setClientID(String clientID) {
|
|
||||||
this.clientID = clientID;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,13 +13,19 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
|
|
||||||
import com.mindspore.flclient.model.AlInferBert;
|
import com.mindspore.flclient.model.AlInferBert;
|
||||||
import com.mindspore.flclient.model.AlTrainBert;
|
import com.mindspore.flclient.model.AlTrainBert;
|
||||||
import com.mindspore.flclient.model.SessionUtil;
|
import com.mindspore.flclient.model.SessionUtil;
|
||||||
import com.mindspore.flclient.model.TrainLenet;
|
import com.mindspore.flclient.model.TrainLenet;
|
||||||
|
|
||||||
import mindspore.schema.FeatureMap;
|
import mindspore.schema.FeatureMap;
|
||||||
import mindspore.schema.RequestGetModel;
|
import mindspore.schema.RequestGetModel;
|
||||||
import mindspore.schema.ResponseCode;
|
import mindspore.schema.ResponseCode;
|
||||||
|
@ -29,57 +35,30 @@ import java.util.ArrayList;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
/**
|
||||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
* Define the serialization method, handle the response message returned from server for getModel request.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class GetModel {
|
public class GetModel {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
|
||||||
|
private static volatile GetModel getModel;
|
||||||
|
|
||||||
static {
|
static {
|
||||||
System.loadLibrary("mindspore-lite-jni");
|
System.loadLibrary("mindspore-lite-jni");
|
||||||
}
|
}
|
||||||
|
|
||||||
class RequestGetModelBuilder {
|
|
||||||
private FlatBufferBuilder builder;
|
|
||||||
private int nameOffset = 0;
|
|
||||||
private int iteration = 0;
|
|
||||||
private int timeStampOffset = 0;
|
|
||||||
|
|
||||||
public RequestGetModelBuilder() {
|
|
||||||
builder = new FlatBufferBuilder();
|
|
||||||
}
|
|
||||||
|
|
||||||
public RequestGetModelBuilder flName(String name) {
|
|
||||||
this.nameOffset = this.builder.createString(name);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RequestGetModelBuilder time() {
|
|
||||||
Date date = new Date();
|
|
||||||
long time = date.getTime();
|
|
||||||
this.timeStampOffset = builder.createString(String.valueOf(time));
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RequestGetModelBuilder iteration(int iteration) {
|
|
||||||
this.iteration = iteration;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public byte[] build() {
|
|
||||||
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
|
|
||||||
builder.finish(root);
|
|
||||||
return builder.sizedByteArray();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final Logger LOGGER = Logger.getLogger(GetModel.class.toString());
|
|
||||||
private static volatile GetModel getModel;
|
|
||||||
|
|
||||||
private GetModel() {
|
|
||||||
}
|
|
||||||
|
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||||
|
|
||||||
|
private GetModel() {
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the singleton object of the class GetModel.
|
||||||
|
*
|
||||||
|
* @return the singleton object of the class GetModel.
|
||||||
|
*/
|
||||||
public static GetModel getInstance() {
|
public static GetModel getInstance() {
|
||||||
GetModel localRef = getModel;
|
GetModel localRef = getModel;
|
||||||
if (localRef == null) {
|
if (localRef == null) {
|
||||||
|
@ -93,7 +72,18 @@ public class GetModel {
|
||||||
return localRef;
|
return localRef;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a flatBuffer builder of RequestGetModel.
|
||||||
|
*
|
||||||
|
* @param name the model name.
|
||||||
|
* @param iteration current iteration of federated learning task.
|
||||||
|
* @return the flatBuffer builder of RequestGetModel in byte[] format.
|
||||||
|
*/
|
||||||
public byte[] getRequestGetModel(String name, int iteration) {
|
public byte[] getRequestGetModel(String name, int iteration) {
|
||||||
|
if (name == null || name.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[GetModel] the input parameter of <name> is null or empty, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
RequestGetModelBuilder builder = new RequestGetModelBuilder();
|
RequestGetModelBuilder builder = new RequestGetModelBuilder();
|
||||||
return builder.iteration(iteration).flName(name).time().build();
|
return builder.iteration(iteration).flName(name).time().build();
|
||||||
}
|
}
|
||||||
|
@ -107,6 +97,10 @@ public class GetModel {
|
||||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||||
for (int i = 0; i < fmCount; i++) {
|
for (int i = 0; i < fmCount; i++) {
|
||||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||||
|
if (feature == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
String featureName = feature.weightFullname();
|
String featureName = feature.weightFullname();
|
||||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||||
albertFeatureMaps.add(feature);
|
albertFeatureMaps.add(feature);
|
||||||
|
@ -116,36 +110,46 @@ public class GetModel {
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
LOGGER.info(Common.addTag("[getModel] weightFullname: " + feature.weightFullname() + ", weightLength:" +
|
||||||
|
" " + feature.dataLength()));
|
||||||
}
|
}
|
||||||
int tag = 0;
|
int tag = 0;
|
||||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference model-----------------"));
|
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into inference " +
|
||||||
|
"model-----------------"));
|
||||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
|
||||||
|
inferFeatureMaps);
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
|
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into train model-----------------"));
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||||
|
albertFeatureMaps);
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
} else if (localFLParameter.getServerMod().equals(ServerMod.FEDERATED_LEARNING.toString())) {
|
||||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||||
for (int i = 0; i < fmCount; i++) {
|
for (int i = 0; i < fmCount; i++) {
|
||||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||||
|
if (feature == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
String featureName = feature.weightFullname();
|
String featureName = feature.weightFullname();
|
||||||
featureMaps.add(feature);
|
featureMaps.add(feature);
|
||||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
|
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
|
||||||
|
feature.dataLength()));
|
||||||
}
|
}
|
||||||
int tag = 0;
|
int tag = 0;
|
||||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||||
|
featureMaps);
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
LOGGER.severe(Common.addTag("[getModel] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
|
@ -159,9 +163,14 @@ public class GetModel {
|
||||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||||
for (int i = 0; i < fmCount; i++) {
|
for (int i = 0; i < fmCount; i++) {
|
||||||
FeatureMap feature = responseDataBuf.featureMap(i);
|
FeatureMap feature = responseDataBuf.featureMap(i);
|
||||||
|
if (feature == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[getModel] the feature returned from server is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
String featureName = feature.weightFullname();
|
String featureName = feature.weightFullname();
|
||||||
featureMaps.add(feature);
|
featureMaps.add(feature);
|
||||||
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " + feature.dataLength()));
|
LOGGER.info(Common.addTag("[getModel] weightFullname: " + featureName + ", weightLength: " +
|
||||||
|
feature.dataLength()));
|
||||||
}
|
}
|
||||||
int tag = 0;
|
int tag = 0;
|
||||||
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
LOGGER.info(Common.addTag("[getModel] ----------------loading weight into model-----------------"));
|
||||||
|
@ -174,7 +183,12 @@ public class GetModel {
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle the response message returned from server.
|
||||||
|
*
|
||||||
|
* @param responseDataBuf the response message returned from server.
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus doResponse(ResponseGetModel responseDataBuf) {
|
public FLClientStatus doResponse(ResponseGetModel responseDataBuf) {
|
||||||
LOGGER.info(Common.addTag("[getModel] ==========get model content is:================"));
|
LOGGER.info(Common.addTag("[getModel] ==========get model content is:================"));
|
||||||
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode()));
|
LOGGER.info(Common.addTag("[getModel] ==========retCode: " + responseDataBuf.retcode()));
|
||||||
|
@ -186,13 +200,15 @@ public class GetModel {
|
||||||
switch (retCode) {
|
switch (retCode) {
|
||||||
case (ResponseCode.SUCCEED):
|
case (ResponseCode.SUCCEED):
|
||||||
LOGGER.info(Common.addTag("[getModel] getModel response success"));
|
LOGGER.info(Common.addTag("[getModel] getModel response success"));
|
||||||
|
|
||||||
if (ALBERT.equals(flParameter.getFlName())) {
|
if (ALBERT.equals(flParameter.getFlName())) {
|
||||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
LOGGER.info(Common.addTag("[getModel] into <parseResponseAlbert>"));
|
||||||
status = parseResponseAlbert(responseDataBuf);
|
status = parseResponseAlbert(responseDataBuf);
|
||||||
} else if (LENET.equals(flParameter.getFlName())) {
|
} else if (LENET.equals(flParameter.getFlName())) {
|
||||||
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
LOGGER.info(Common.addTag("[getModel] into <parseResponseLenet>"));
|
||||||
status = parseResponseLenet(responseDataBuf);
|
status = parseResponseLenet(responseDataBuf);
|
||||||
|
} else {
|
||||||
|
LOGGER.severe(Common.addTag("[getModel] the flName is not valid, only support: lenet, albert"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return status;
|
return status;
|
||||||
case (ResponseCode.SucNotReady):
|
case (ResponseCode.SucNotReady):
|
||||||
|
@ -211,4 +227,42 @@ public class GetModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class RequestGetModelBuilder {
|
||||||
|
private FlatBufferBuilder builder;
|
||||||
|
private int nameOffset = 0;
|
||||||
|
private int iteration = 0;
|
||||||
|
private int timeStampOffset = 0;
|
||||||
|
|
||||||
|
public RequestGetModelBuilder() {
|
||||||
|
builder = new FlatBufferBuilder();
|
||||||
|
}
|
||||||
|
|
||||||
|
private RequestGetModelBuilder flName(String name) {
|
||||||
|
if (name == null || name.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[GetModel] the input parameter of <name> is null or empty, please " +
|
||||||
|
"check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
this.nameOffset = this.builder.createString(name);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private RequestGetModelBuilder time() {
|
||||||
|
Date date = new Date();
|
||||||
|
long time = date.getTime();
|
||||||
|
this.timeStampOffset = builder.createString(String.valueOf(time));
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private RequestGetModelBuilder iteration(int iteration) {
|
||||||
|
this.iteration = iteration;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private byte[] build() {
|
||||||
|
int root = RequestGetModel.createRequestGetModel(this.builder, nameOffset, iteration, timeStampOffset);
|
||||||
|
builder.finish(root);
|
||||||
|
return builder.sizedByteArray();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,25 +1,42 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
* You may obtain a copy of the License at
|
* You may obtain a copy of the License at
|
||||||
*
|
*
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define asynchronous communication call back interface.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public interface IAsyncCallBack {
|
public interface IAsyncCallBack {
|
||||||
public FLClientStatus onFailure(IOException exception);
|
/**
|
||||||
|
* Automatically invoked when the request fails.
|
||||||
|
*
|
||||||
|
* @param exception the catch exception.
|
||||||
|
* @return the status code in client.
|
||||||
|
*/
|
||||||
|
FLClientStatus onFailure(IOException exception);
|
||||||
|
|
||||||
public FLClientStatus onResponse(byte[] msg);
|
/**
|
||||||
|
* Automatically invoked when a response message is processed.
|
||||||
|
*
|
||||||
|
* @param msg the response message.
|
||||||
|
* @return the status code in client.
|
||||||
|
*/
|
||||||
|
FLClientStatus onResponse(byte[] msg);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,33 +1,54 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
* You may obtain a copy of the License at
|
* You may obtain a copy of the License at
|
||||||
*
|
*
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
import java.util.concurrent.TimeoutException;
|
import java.util.concurrent.TimeoutException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author smurf
|
* Define basic communication interface.
|
||||||
*
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
*/
|
*/
|
||||||
public interface IFLCommunication {
|
public interface IFLCommunication {
|
||||||
|
/**
|
||||||
|
* Sets the timeout interval for communication on the device.
|
||||||
|
*
|
||||||
|
* @param timeout the timeout interval for communication on the device.
|
||||||
|
* @throws TimeoutException catch TimeoutException.
|
||||||
|
*/
|
||||||
|
void setTimeOut(int timeout) throws TimeoutException;
|
||||||
|
|
||||||
public void setTimeOut(int timeout) throws TimeoutException;
|
/**
|
||||||
|
* Synchronization request function.
|
||||||
public byte[] syncRequest(String url, byte[] msg) throws Exception;
|
*
|
||||||
|
* @param url the URL for device-sever interaction set by user.
|
||||||
public void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception;
|
* @param msg the message need to be sent to server.
|
||||||
}
|
* @return the response message.
|
||||||
|
* @throws Exception catch Exception.
|
||||||
|
*/
|
||||||
|
byte[] syncRequest(String url, byte[] msg) throws Exception;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Asynchronous request function.
|
||||||
|
*
|
||||||
|
* @param url the URL for device-sever interaction set by user.
|
||||||
|
* @param msg the message need to be sent to server.
|
||||||
|
* @param callBack the call back object.
|
||||||
|
* @throws Exception catch Exception.
|
||||||
|
*/
|
||||||
|
void asyncRequest(String url, byte[] msg, IAsyncCallBack callBack) throws Exception;
|
||||||
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,22 +13,30 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define job result callback function interface.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public interface IFLJobResultCallback {
|
public interface IFLJobResultCallback {
|
||||||
/**
|
/**
|
||||||
* Called at the end of an iteration for Fl job
|
* Called at the end of an iteration for Fl job
|
||||||
* @param modelName the name of model
|
*
|
||||||
|
* @param modelName the name of model
|
||||||
* @param iterationSeq Iteration number
|
* @param iterationSeq Iteration number
|
||||||
* @param resultCode Status Code
|
* @param resultCode Status Code
|
||||||
*/
|
*/
|
||||||
public void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode);
|
void onFlJobIterationFinished(String modelName, int iterationSeq, int resultCode);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Called on completion for Fl job
|
* Called on completion for Fl job
|
||||||
* @param modelName the name of model
|
*
|
||||||
|
* @param modelName the name of model
|
||||||
* @param iterationCount total Iteration numbers
|
* @param iterationCount total Iteration numbers
|
||||||
* @param resultCode Status Code
|
* @param resultCode Status Code
|
||||||
*/
|
*/
|
||||||
public void onFlJobFinished(String modelName, int iterationCount, int resultCode);
|
void onFlJobFinished(String modelName, int iterationCount, int resultCode);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,28 +13,60 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
import org.bouncycastle.math.ec.rfc7748.X25519;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defines global parameters used internally during federated learning.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class LocalFLParameter {
|
public class LocalFLParameter {
|
||||||
private static final Logger LOGGER = Logger.getLogger(LocalFLParameter.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(LocalFLParameter.class.toString());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Seed length used to generate random perturbations
|
||||||
|
*/
|
||||||
public static final int SEED_SIZE = 32;
|
public static final int SEED_SIZE = 32;
|
||||||
public static final int IVEC_LEN = 16;
|
|
||||||
|
/**
|
||||||
|
* The length of IV value
|
||||||
|
*/
|
||||||
|
public static final int I_VEC_LEN = 16;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The length of salt value
|
||||||
|
*/
|
||||||
|
public static final int SALT_SIZE = 32;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* the key length
|
||||||
|
*/
|
||||||
|
public static final int KEY_LEN = X25519.SCALAR_SIZE;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The model name supported by federated learning tasks: "lenet".
|
||||||
|
*/
|
||||||
public static final String LENET = "lenet";
|
public static final String LENET = "lenet";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The model name supported by federated learning tasks: "albert".
|
||||||
|
*/
|
||||||
public static final String ALBERT = "albert";
|
public static final String ALBERT = "albert";
|
||||||
|
private static volatile LocalFLParameter localFLParameter;
|
||||||
|
|
||||||
private List<String> classifierWeightName = new ArrayList<>();
|
private List<String> classifierWeightName = new ArrayList<>();
|
||||||
private List<String> albertWeightName = new ArrayList<>();
|
private List<String> albertWeightName = new ArrayList<>();
|
||||||
|
|
||||||
private String flID;
|
private String flID;
|
||||||
private String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString();
|
private String encryptLevel = EncryptLevel.NOT_ENCRYPT.toString();
|
||||||
private String earlyStopMod = EarlyStopMod.NOT_EARLY_STOP.toString();
|
private String earlyStopMod = EarlyStopMod.NOT_EARLY_STOP.toString();
|
||||||
private String serverMod = ServerMod.HYBRID_TRAINING.toString();
|
private String serverMod = ServerMod.HYBRID_TRAINING.toString();
|
||||||
private String safeMod = "The cluster is in safemode.";
|
|
||||||
|
|
||||||
private static volatile LocalFLParameter localFLParameter;
|
|
||||||
|
|
||||||
private LocalFLParameter() {
|
private LocalFLParameter() {
|
||||||
// set classifierWeightName albertWeightName
|
// set classifierWeightName albertWeightName
|
||||||
|
@ -42,6 +74,11 @@ public class LocalFLParameter {
|
||||||
Common.setAlbertWeightName(albertWeightName);
|
Common.setAlbertWeightName(albertWeightName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the singleton object of the class LocalFLParameter.
|
||||||
|
*
|
||||||
|
* @return the singleton object of the class LocalFLParameter.
|
||||||
|
*/
|
||||||
public static LocalFLParameter getInstance() {
|
public static LocalFLParameter getInstance() {
|
||||||
LocalFLParameter localRef = localFLParameter;
|
LocalFLParameter localRef = localFLParameter;
|
||||||
if (localRef == null) {
|
if (localRef == null) {
|
||||||
|
@ -57,8 +94,9 @@ public class LocalFLParameter {
|
||||||
|
|
||||||
public List<String> getClassifierWeightName() {
|
public List<String> getClassifierWeightName() {
|
||||||
if (classifierWeightName.isEmpty()) {
|
if (classifierWeightName.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please " +
|
||||||
throw new RuntimeException();
|
"set it before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return classifierWeightName;
|
return classifierWeightName;
|
||||||
}
|
}
|
||||||
|
@ -69,8 +107,9 @@ public class LocalFLParameter {
|
||||||
|
|
||||||
public List<String> getAlbertWeightName() {
|
public List<String> getAlbertWeightName() {
|
||||||
if (albertWeightName.isEmpty()) {
|
if (albertWeightName.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <classifierWeightName> is null, please " +
|
||||||
throw new RuntimeException();
|
"set it before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return albertWeightName;
|
return albertWeightName;
|
||||||
}
|
}
|
||||||
|
@ -80,14 +119,20 @@ public class LocalFLParameter {
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getFlID() {
|
public String getFlID() {
|
||||||
if ("".equals(flID) || flID == null) {
|
if (flID == null || flID.isEmpty()) {
|
||||||
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before use"));
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please set it before " +
|
||||||
throw new RuntimeException();
|
"use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
return flID;
|
return flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setFlID(String flID) {
|
public void setFlID(String flID) {
|
||||||
|
if (flID == null || flID.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <flID> is null, please check it before " +
|
||||||
|
"set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
this.flID = flID;
|
this.flID = flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,6 +141,18 @@ public class LocalFLParameter {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setEncryptLevel(String encryptLevel) {
|
public void setEncryptLevel(String encryptLevel) {
|
||||||
|
if (encryptLevel == null || encryptLevel.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is null, please check it " +
|
||||||
|
"before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
if ((!EncryptLevel.DP_ENCRYPT.toString().equals(encryptLevel)) &&
|
||||||
|
(!EncryptLevel.NOT_ENCRYPT.toString().equals(encryptLevel)) &&
|
||||||
|
(!EncryptLevel.PW_ENCRYPT.toString().equals(encryptLevel))) {
|
||||||
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <encryptLevel> is " + encryptLevel + " ," +
|
||||||
|
" it must be DP_ENCRYPT or NOT_ENCRYPT or PW_ENCRYPT, please check it before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
this.encryptLevel = encryptLevel;
|
this.encryptLevel = encryptLevel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,6 +161,19 @@ public class LocalFLParameter {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setEarlyStopMod(String earlyStopMod) {
|
public void setEarlyStopMod(String earlyStopMod) {
|
||||||
|
if (earlyStopMod == null || earlyStopMod.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is null, please check it " +
|
||||||
|
"before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
if ((!EarlyStopMod.NOT_EARLY_STOP.toString().equals(earlyStopMod)) &&
|
||||||
|
(!EarlyStopMod.LOSS_ABS.toString().equals(earlyStopMod)) &&
|
||||||
|
(!EarlyStopMod.LOSS_DIFF.toString().equals(earlyStopMod)) &&
|
||||||
|
(!EarlyStopMod.WEIGHT_DIFF.toString().equals(earlyStopMod))) {
|
||||||
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <earlyStopMod> is " + earlyStopMod + " ," +
|
||||||
|
" it must be NOT_EARLY_STOP or LOSS_ABS or LOSS_DIFF or WEIGHT_DIFF, please check it before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
this.earlyStopMod = earlyStopMod;
|
this.earlyStopMod = earlyStopMod;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,14 +182,17 @@ public class LocalFLParameter {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setServerMod(String serverMod) {
|
public void setServerMod(String serverMod) {
|
||||||
|
if (serverMod == null || serverMod.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is null, please check it " +
|
||||||
|
"before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
if ((!ServerMod.HYBRID_TRAINING.toString().equals(serverMod)) &&
|
||||||
|
(!ServerMod.FEDERATED_LEARNING.toString().equals(serverMod))) {
|
||||||
|
LOGGER.severe(Common.addTag("[localFLParameter] the parameter of <serverMod> is " + serverMod + " , it " +
|
||||||
|
"must be HYBRID_TRAINING or FEDERATED_LEARNING, please check it before set"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
this.serverMod = serverMod;
|
this.serverMod = serverMod;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getSafeMod() {
|
|
||||||
return safeMod;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setSafeMod(String safeMod) {
|
|
||||||
this.safeMod = safeMod;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,33 +13,73 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.security.InvalidKeyException;
|
||||||
|
import java.security.KeyManagementException;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import java.security.NoSuchProviderException;
|
||||||
|
import java.security.SignatureException;
|
||||||
|
import java.security.cert.Certificate;
|
||||||
|
import java.security.cert.CertificateException;
|
||||||
|
import java.security.cert.CertificateFactory;
|
||||||
|
import java.security.cert.X509Certificate;
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import javax.net.ssl.HostnameVerifier;
|
import javax.net.ssl.HostnameVerifier;
|
||||||
import javax.net.ssl.SSLContext;
|
import javax.net.ssl.SSLContext;
|
||||||
import javax.net.ssl.SSLSession;
|
import javax.net.ssl.SSLSession;
|
||||||
import javax.net.ssl.SSLSocketFactory;
|
import javax.net.ssl.SSLSocketFactory;
|
||||||
import javax.net.ssl.TrustManager;
|
import javax.net.ssl.TrustManager;
|
||||||
import javax.net.ssl.X509TrustManager;
|
import javax.net.ssl.X509TrustManager;
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.security.InvalidKeyException;
|
|
||||||
import java.security.NoSuchAlgorithmException;
|
|
||||||
import java.security.NoSuchProviderException;
|
|
||||||
import java.security.SignatureException;
|
|
||||||
import java.security.cert.CertificateException;
|
|
||||||
import java.security.cert.CertificateFactory;
|
|
||||||
import java.security.cert.X509Certificate;
|
|
||||||
import java.util.logging.Logger;
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define SSL socket factory tools for https communication.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class SSLSocketFactoryTools {
|
public class SSLSocketFactoryTools {
|
||||||
private static final Logger LOGGER = Logger.getLogger(SSLSocketFactory.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(SSLSocketFactory.class.toString());
|
||||||
|
private static volatile SSLSocketFactoryTools sslSocketFactoryTools;
|
||||||
|
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
private X509Certificate x509Certificate;
|
private X509Certificate x509Certificate;
|
||||||
private SSLSocketFactory sslSocketFactory;
|
private SSLSocketFactory sslSocketFactory;
|
||||||
private SSLContext sslContext;
|
private SSLContext sslContext;
|
||||||
private MyTrustManager myTrustManager;
|
private MyTrustManager myTrustManager;
|
||||||
private static volatile SSLSocketFactoryTools sslSocketFactoryTools;
|
private final HostnameVerifier hostnameVerifier = new HostnameVerifier() {
|
||||||
|
@Override
|
||||||
|
public boolean verify(String hostname, SSLSession session) {
|
||||||
|
if (hostname == null || hostname.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <hostname> is null or empty, " +
|
||||||
|
"please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
if (session == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <session> is null, please " +
|
||||||
|
"check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
String domainName = flParameter.getDomainName();
|
||||||
|
if ((domainName == null || domainName.isEmpty() || domainName.split("//").length < 2)) {
|
||||||
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the <domainName> is null or not valid, it should" +
|
||||||
|
" be like as https://...... , please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
if (domainName.split("//")[1].split(":").length < 2) {
|
||||||
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the format of <domainName> is not valid, it " +
|
||||||
|
"should be like as https://127.0.0.1:6666 when setting <useSSL> to true, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
String ip = domainName.split("//")[1].split(":")[0];
|
||||||
|
return hostname.equals(ip);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
private SSLSocketFactoryTools() {
|
private SSLSocketFactoryTools() {
|
||||||
initSslSocketFactory();
|
initSslSocketFactory();
|
||||||
|
@ -52,14 +92,19 @@ public class SSLSocketFactoryTools {
|
||||||
myTrustManager = new MyTrustManager(x509Certificate);
|
myTrustManager = new MyTrustManager(x509Certificate);
|
||||||
sslContext.init(null, new TrustManager[]{
|
sslContext.init(null, new TrustManager[]{
|
||||||
myTrustManager
|
myTrustManager
|
||||||
}, new java.security.SecureRandom());
|
}, Common.getSecureRandom());
|
||||||
sslSocketFactory = sslContext.getSocketFactory();
|
sslSocketFactory = sslContext.getSocketFactory();
|
||||||
|
} catch (NoSuchAlgorithmException | KeyManagementException ex) {
|
||||||
} catch (Exception e) {
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " +
|
||||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools]catch Exception in initSslSocketFactory: " + e.getMessage()));
|
ex.getMessage()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the singleton object of the class SSLSocketFactoryTools.
|
||||||
|
*
|
||||||
|
* @return the singleton object of the class SSLSocketFactoryTools.
|
||||||
|
*/
|
||||||
public static SSLSocketFactoryTools getInstance() {
|
public static SSLSocketFactoryTools getInstance() {
|
||||||
SSLSocketFactoryTools localRef = sslSocketFactoryTools;
|
SSLSocketFactoryTools localRef = sslSocketFactoryTools;
|
||||||
if (localRef == null) {
|
if (localRef == null) {
|
||||||
|
@ -73,29 +118,37 @@ public class SSLSocketFactoryTools {
|
||||||
return localRef;
|
return localRef;
|
||||||
}
|
}
|
||||||
|
|
||||||
public X509Certificate readCert(String assetName) {
|
private X509Certificate readCert(String assetName) {
|
||||||
InputStream inputStream = null;
|
if (assetName == null || assetName.isEmpty()) {
|
||||||
try {
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] the parameter of <assetName> is null or empty, " +
|
||||||
inputStream = new FileInputStream(assetName);
|
"please check!"));
|
||||||
} catch (Exception e) {
|
|
||||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of read inputStream in readCert: " + e.getMessage()));
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
InputStream inputStream = null;
|
||||||
X509Certificate cert = null;
|
X509Certificate cert = null;
|
||||||
try {
|
try {
|
||||||
|
inputStream = new FileInputStream(assetName);
|
||||||
CertificateFactory cf = CertificateFactory.getInstance("X.509");
|
CertificateFactory cf = CertificateFactory.getInstance("X.509");
|
||||||
cert = (X509Certificate) cf.generateCertificate(inputStream);
|
Certificate certificate = cf.generateCertificate(inputStream);
|
||||||
} catch (Exception e) {
|
if (certificate instanceof X509Certificate) {
|
||||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch Exception of creating CertificateFactory in readCert: " + e.getMessage()));
|
cert = (X509Certificate) certificate;
|
||||||
|
} else {
|
||||||
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] cf.generateCertificate(inputStream) can not " +
|
||||||
|
"convert to X509Certificate"));
|
||||||
|
}
|
||||||
|
} catch (FileNotFoundException | CertificateException ex) {
|
||||||
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch FileNotFoundException or CertificateException " +
|
||||||
|
"when creating " +
|
||||||
|
"CertificateFactory in readCert: " + ex.getMessage()));
|
||||||
} finally {
|
} finally {
|
||||||
try {
|
try {
|
||||||
if (inputStream != null) {
|
if (inputStream != null) {
|
||||||
inputStream.close();
|
inputStream.close();
|
||||||
}
|
}
|
||||||
} catch (Throwable ex) {
|
} catch (IOException ex) {
|
||||||
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch IOException: " + ex.getMessage()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return cert;
|
return cert;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,7 +164,6 @@ public class SSLSocketFactoryTools {
|
||||||
return myTrustManager;
|
return myTrustManager;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private static final class MyTrustManager implements X509TrustManager {
|
private static final class MyTrustManager implements X509TrustManager {
|
||||||
X509Certificate cert;
|
X509Certificate cert;
|
||||||
|
|
||||||
|
@ -126,27 +178,25 @@ public class SSLSocketFactoryTools {
|
||||||
@Override
|
@Override
|
||||||
public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
|
public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
|
||||||
for (X509Certificate cert : chain) {
|
for (X509Certificate cert : chain) {
|
||||||
|
|
||||||
// Make sure that it hasn't expired.
|
// Make sure that it hasn't expired.
|
||||||
cert.checkValidity();
|
cert.checkValidity();
|
||||||
|
|
||||||
// Verify the certificate's public key chain.
|
// Verify the certificate's public key chain.
|
||||||
try {
|
try {
|
||||||
cert.verify(((X509Certificate) this.cert).getPublicKey());
|
cert.verify(this.cert.getPublicKey());
|
||||||
} catch (NoSuchAlgorithmException e) {
|
} catch (NoSuchAlgorithmException e) {
|
||||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage()));
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
|
||||||
throw new RuntimeException();
|
"NoSuchAlgorithmException in checkServerTrusted: " + e.getMessage()));
|
||||||
} catch (InvalidKeyException e) {
|
} catch (InvalidKeyException e) {
|
||||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch InvalidKeyException in checkServerTrusted: " + e.getMessage()));
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
|
||||||
throw new RuntimeException();
|
"InvalidKeyException in checkServerTrusted: " + e.getMessage()));
|
||||||
} catch (NoSuchProviderException e) {
|
} catch (NoSuchProviderException e) {
|
||||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch NoSuchProviderException in checkServerTrusted: " + e.getMessage()));
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
|
||||||
throw new RuntimeException();
|
"NoSuchProviderException in checkServerTrusted: " + e.getMessage()));
|
||||||
} catch (SignatureException e) {
|
} catch (SignatureException e) {
|
||||||
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] catch SignatureException in checkServerTrusted: " + e.getMessage()));
|
LOGGER.severe(Common.addTag("[SSLSocketFactoryTools] checkServerTrusted failed, catch " +
|
||||||
throw new RuntimeException();
|
"SignatureException in checkServerTrusted: " + e.getMessage()));
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("checkServerTrusted success!"));
|
LOGGER.info(Common.addTag("**********************checkServerTrusted success!**********************"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,14 +205,4 @@ public class SSLSocketFactoryTools {
|
||||||
return new java.security.cert.X509Certificate[0];
|
return new java.security.cert.X509Certificate[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private final HostnameVerifier hostnameVerifier = new HostnameVerifier() {
|
|
||||||
@Override
|
|
||||||
public boolean verify(String hostname, SSLSession session) {
|
|
||||||
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] server hostname: " + flParameter.getHostName()));
|
|
||||||
LOGGER.info(Common.addTag("[SSLSocketFactoryTools] client request hostname: " + hostname));
|
|
||||||
return hostname.equals(flParameter.getHostName());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,127 +13,179 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
|
|
||||||
import com.mindspore.flclient.model.AlTrainBert;
|
import com.mindspore.flclient.model.AlTrainBert;
|
||||||
import com.mindspore.flclient.model.SessionUtil;
|
import com.mindspore.flclient.model.SessionUtil;
|
||||||
import com.mindspore.flclient.model.TrainLenet;
|
import com.mindspore.flclient.model.TrainLenet;
|
||||||
|
|
||||||
import mindspore.schema.FeatureMap;
|
import mindspore.schema.FeatureMap;
|
||||||
|
|
||||||
import java.security.SecureRandom;
|
import java.security.SecureRandom;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Random;
|
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
/**
|
||||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
* Defines encryption and decryption methods.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class SecureProtocol {
|
public class SecureProtocol {
|
||||||
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(SecureProtocol.class.toString());
|
||||||
|
private static double deltaError = 1e-6d;
|
||||||
|
private static Map<String, float[]> modelMap;
|
||||||
|
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||||
private int iteration;
|
private int iteration;
|
||||||
private CipherClient cipher;
|
private CipherClient cipherClient;
|
||||||
private FLClientStatus status;
|
private FLClientStatus status;
|
||||||
private float[] featureMask = new float[0];
|
private float[] featureMask = new float[0];
|
||||||
private double dpEps;
|
private double dpEps;
|
||||||
private double dpDelta;
|
private double dpDelta;
|
||||||
private double dpNormClip;
|
private double dpNormClip;
|
||||||
private static double deltaError = 1e-6;
|
|
||||||
private static Map<String, float[]> modelMap;
|
|
||||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||||
private int retCode;
|
private int retCode;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain current status code in client.
|
||||||
|
*
|
||||||
|
* @return current status code in client.
|
||||||
|
*/
|
||||||
public FLClientStatus getStatus() {
|
public FLClientStatus getStatus() {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
public float[] getFeatureMask() {
|
/**
|
||||||
return featureMask;
|
* Obtain retCode returned by server.
|
||||||
}
|
*
|
||||||
|
* @return the retCode returned by server.
|
||||||
|
*/
|
||||||
public int getRetCode() {
|
public int getRetCode() {
|
||||||
return retCode;
|
return retCode;
|
||||||
}
|
}
|
||||||
|
|
||||||
public SecureProtocol() {
|
/**
|
||||||
}
|
* Setting parameters for pairwise masking.
|
||||||
|
*
|
||||||
|
* @param iter current iteration of federated learning task.
|
||||||
|
* @param minSecretNum minimum number of secret fragments required to reconstruct a secret
|
||||||
|
* @param prime teh big prime number used to split secrets into pieces
|
||||||
|
* @param featureSize the total feature size in model
|
||||||
|
*/
|
||||||
public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) {
|
public void setPWParameter(int iter, int minSecretNum, byte[] prime, int featureSize) {
|
||||||
this.iteration = iter;
|
if (prime == null || prime.length == 0) {
|
||||||
this.cipher = new CipherClient(iteration, minSecretNum, prime, featureSize);
|
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <prime> is null, please check!"));
|
||||||
}
|
throw new IllegalArgumentException();
|
||||||
|
|
||||||
public FLClientStatus setDPParameter(int iter, double diffEps,
|
|
||||||
double diffDelta, double diffNorm, Map<String, float[]> map) {
|
|
||||||
try {
|
|
||||||
this.iteration = iter;
|
|
||||||
this.dpEps = diffEps;
|
|
||||||
this.dpDelta = diffDelta;
|
|
||||||
this.dpNormClip = diffNorm;
|
|
||||||
this.modelMap = map;
|
|
||||||
status = FLClientStatus.SUCCESS;
|
|
||||||
} catch (Exception e) {
|
|
||||||
LOGGER.severe(Common.addTag("[DPEncrypt] catch Exception in setDPParameter: " + e.getMessage()));
|
|
||||||
status = FLClientStatus.FAILED;
|
|
||||||
}
|
}
|
||||||
return status;
|
this.iteration = iter;
|
||||||
|
this.cipherClient = new CipherClient(iteration, minSecretNum, prime, featureSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Setting parameters for differential privacy.
|
||||||
|
*
|
||||||
|
* @param iter current iteration of federated learning task.
|
||||||
|
* @param diffEps privacy budget eps of DP mechanism.
|
||||||
|
* @param diffDelta privacy budget delta of DP mechanism.
|
||||||
|
* @param diffNorm normClip factor of DP mechanism.
|
||||||
|
* @param map model weights.
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
|
public FLClientStatus setDPParameter(int iter, double diffEps, double diffDelta, double diffNorm, Map<String,
|
||||||
|
float[]> map) {
|
||||||
|
this.iteration = iter;
|
||||||
|
this.dpEps = diffEps;
|
||||||
|
this.dpDelta = diffDelta;
|
||||||
|
this.dpNormClip = diffNorm;
|
||||||
|
this.modelMap = map;
|
||||||
|
return FLClientStatus.SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain the feature names that needed to be encrypted.
|
||||||
|
*
|
||||||
|
* @return the feature names that needed to be encrypted.
|
||||||
|
*/
|
||||||
public ArrayList<String> getEncryptFeatureName() {
|
public ArrayList<String> getEncryptFeatureName() {
|
||||||
return encryptFeatureName;
|
return encryptFeatureName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the parameter encryptFeatureName.
|
||||||
|
*
|
||||||
|
* @param encryptFeatureName the feature names that needed to be encrypted.
|
||||||
|
*/
|
||||||
public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) {
|
public void setEncryptFeatureName(ArrayList<String> encryptFeatureName) {
|
||||||
this.encryptFeatureName = encryptFeatureName;
|
this.encryptFeatureName = encryptFeatureName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtain the returned timestamp for next request from server.
|
||||||
|
*
|
||||||
|
* @return the timestamp for next request.
|
||||||
|
*/
|
||||||
public String getNextRequestTime() {
|
public String getNextRequestTime() {
|
||||||
return cipher.getNextRequestTime();
|
return cipherClient.getNextRequestTime();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate pairwise mask and individual mask.
|
||||||
|
*
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus pwCreateMask() {
|
public FLClientStatus pwCreateMask() {
|
||||||
LOGGER.info("[PairWiseMask] ==============request flID: " + localFLParameter.getFlID() + "==============");
|
LOGGER.info(String.format("[PairWiseMask] ==============request flID: %s ==============",
|
||||||
|
localFLParameter.getFlID()));
|
||||||
// round 0
|
// round 0
|
||||||
status = cipher.exchangeKeys();
|
status = cipherClient.exchangeKeys();
|
||||||
retCode = cipher.getRetCode();
|
retCode = cipherClient.getRetCode();
|
||||||
LOGGER.info("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: " + status + "============");
|
LOGGER.info(String.format("[PairWiseMask] ============= RequestExchangeKeys+GetExchangeKeys response: %s ",
|
||||||
|
"============", status));
|
||||||
if (status != FLClientStatus.SUCCESS) {
|
if (status != FLClientStatus.SUCCESS) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
// round 1
|
// round 1
|
||||||
try {
|
status = cipherClient.shareSecrets();
|
||||||
status = cipher.shareSecrets();
|
retCode = cipherClient.getRetCode();
|
||||||
retCode = cipher.getRetCode();
|
LOGGER.info(String.format("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: %s ",
|
||||||
LOGGER.info("[Encrypt] =============RequestShareSecrets+GetShareSecrets response: " + status + "=============");
|
"=============", status));
|
||||||
} catch (Exception e) {
|
|
||||||
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
|
|
||||||
status = FLClientStatus.FAILED;
|
|
||||||
}
|
|
||||||
if (status != FLClientStatus.SUCCESS) {
|
if (status != FLClientStatus.SUCCESS) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
// round2
|
// round2
|
||||||
try {
|
featureMask = cipherClient.doubleMaskingWeight();
|
||||||
featureMask = cipher.doubleMaskingWeight();
|
if (featureMask == null || featureMask.length <= 0) {
|
||||||
retCode = cipher.getRetCode();
|
LOGGER.severe(Common.addTag("[Encrypt] the returned featureMask from cipherClient.doubleMaskingWeight" +
|
||||||
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
|
" is null, please check!"));
|
||||||
} catch (Exception e) {
|
return FLClientStatus.FAILED;
|
||||||
LOGGER.severe("[PairWiseMask] catch Exception in pwCreateMask");
|
|
||||||
status = FLClientStatus.FAILED;
|
|
||||||
}
|
}
|
||||||
|
retCode = cipherClient.getRetCode();
|
||||||
|
LOGGER.info("[Encrypt] =============Create double feature mask: SUCCESS=============");
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add the pairwise mask and individual mask to model weights.
|
||||||
|
*
|
||||||
|
* @param builder the FlatBufferBuilder object used for serialization model weights.
|
||||||
|
* @param trainDataSize trainDataSize tne size of train data set.
|
||||||
|
* @return the serialized model weights after adding masks.
|
||||||
|
*/
|
||||||
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
public int[] pwMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||||
if (featureMask == null || featureMask.length == 0) {
|
if (featureMask == null || featureMask.length == 0) {
|
||||||
LOGGER.severe("[Encrypt] feature mask is null, please check");
|
LOGGER.severe("[Encrypt] feature mask is null, please check");
|
||||||
return new int[0];
|
return new int[0];
|
||||||
}
|
}
|
||||||
LOGGER.info("[Encrypt] feature mask size: " + featureMask.length);
|
LOGGER.info(String.format("[Encrypt] feature mask size: %s", featureMask.length));
|
||||||
// get feature map
|
// get feature map
|
||||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
|
@ -142,6 +194,9 @@ public class SecureProtocol {
|
||||||
} else if (flParameter.getFlName().equals(LENET)) {
|
} else if (flParameter.getFlName().equals(LENET)) {
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||||
|
} else {
|
||||||
|
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
int featureSize = encryptFeatureName.size();
|
int featureSize = encryptFeatureName.size();
|
||||||
int[] featuresMap = new int[featureSize];
|
int[] featuresMap = new int[featureSize];
|
||||||
|
@ -149,9 +204,13 @@ public class SecureProtocol {
|
||||||
for (int i = 0; i < featureSize; i++) {
|
for (int i = 0; i < featureSize; i++) {
|
||||||
String key = encryptFeatureName.get(i);
|
String key = encryptFeatureName.get(i);
|
||||||
float[] data = map.get(key);
|
float[] data = map.get(key);
|
||||||
LOGGER.info("[Encrypt] feature name: " + key + " feature size: " + data.length);
|
LOGGER.info(String.format("[Encrypt] feature name: %s feature size: %s", key, data.length));
|
||||||
for (int j = 0; j < data.length; j++) {
|
for (int j = 0; j < data.length; j++) {
|
||||||
float rawData = data[j];
|
float rawData = data[j];
|
||||||
|
if (maskIndex >= featureMask.length) {
|
||||||
|
LOGGER.severe("[Encrypt] the maskIndex is out of range for array featureMask, please check");
|
||||||
|
return new int[0];
|
||||||
|
}
|
||||||
float maskData = rawData * trainDataSize + featureMask[maskIndex];
|
float maskData = rawData * trainDataSize + featureMask[maskIndex];
|
||||||
maskIndex += 1;
|
maskIndex += 1;
|
||||||
data[j] = maskData;
|
data[j] = maskData;
|
||||||
|
@ -164,17 +223,23 @@ public class SecureProtocol {
|
||||||
return featuresMap;
|
return featuresMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reconstruct the secrets used for unmasking model weights.
|
||||||
|
*
|
||||||
|
* @return current status code in client.
|
||||||
|
*/
|
||||||
public FLClientStatus pwUnmasking() {
|
public FLClientStatus pwUnmasking() {
|
||||||
status = cipher.reconstructSecrets(); // round3
|
status = cipherClient.reconstructSecrets(); // round3
|
||||||
retCode = cipher.getRetCode();
|
retCode = cipherClient.getRetCode();
|
||||||
LOGGER.info("[Encrypt] =============GetClientList+SendReconstructSecret: " + status + "=============");
|
LOGGER.info(String.format("[Encrypt] =============GetClientList+SendReconstructSecret: %s =============",
|
||||||
|
status));
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static float calculateErf(double x) {
|
private static float calculateErf(double erfInput) {
|
||||||
double result = 0;
|
double result = 0d;
|
||||||
int segmentNum = 10000;
|
int segmentNum = 10000;
|
||||||
double deltaX = x / segmentNum;
|
double deltaX = erfInput / segmentNum;
|
||||||
result += 1;
|
result += 1;
|
||||||
for (int i = 1; i < segmentNum; i++) {
|
for (int i = 1; i < segmentNum; i++) {
|
||||||
result += 2 * Math.exp(-Math.pow(deltaX * i, 2));
|
result += 2 * Math.exp(-Math.pow(deltaX * i, 2));
|
||||||
|
@ -183,33 +248,36 @@ public class SecureProtocol {
|
||||||
return (float) (result * deltaX / Math.pow(Math.PI, 0.5));
|
return (float) (result * deltaX / Math.pow(Math.PI, 0.5));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static double calculatePhi(double t) {
|
private static double calculatePhi(double phiInput) {
|
||||||
return 0.5 * (1.0 + calculateErf((t / Math.sqrt(2.0))));
|
return 0.5 * (1.0 + calculateErf((phiInput / Math.sqrt(2.0))));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static double calculateBPositive(double eps, double s) {
|
private static double calculateBPositive(double eps, double calInput) {
|
||||||
return calculatePhi(Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
|
return calculatePhi(Math.sqrt(eps * calInput)) -
|
||||||
|
Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static double calculateBNegative(double eps, double s) {
|
private static double calculateBNegative(double eps, double calInput) {
|
||||||
return calculatePhi(-Math.sqrt(eps * s)) - Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (s + 2.0)));
|
return calculatePhi(-Math.sqrt(eps * calInput)) -
|
||||||
|
Math.exp(eps) * calculatePhi(-Math.sqrt(eps * (calInput + 2.0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static double calculateSPositive(double eps, double targetDelta, double sInf, double sSup) {
|
private static double calculateSPositive(double eps, double targetDelta, double initSInf, double initSSup) {
|
||||||
double deltaSup = calculateBPositive(eps, sSup);
|
double deltaSup = calculateBPositive(eps, initSSup);
|
||||||
|
double sInf = initSInf;
|
||||||
|
double sSup = initSSup;
|
||||||
while (deltaSup <= targetDelta) {
|
while (deltaSup <= targetDelta) {
|
||||||
sInf = sSup;
|
sInf = sSup;
|
||||||
sSup = 2 * sInf;
|
sSup = 2 * sInf;
|
||||||
deltaSup = calculateBPositive(eps, sSup);
|
deltaSup = calculateBPositive(eps, sSup);
|
||||||
}
|
}
|
||||||
|
|
||||||
double sMid = sInf + (sSup - sInf) / 2.0;
|
double sMid = sInf + (sSup - sInf) / 2.0;
|
||||||
int iterMax = 1000;
|
int iterMax = 1000;
|
||||||
int iters = 0;
|
int iters = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
double b = calculateBPositive(eps, sMid);
|
double bPositive = calculateBPositive(eps, sMid);
|
||||||
if (b <= targetDelta) {
|
if (bPositive <= targetDelta) {
|
||||||
if (targetDelta - b <= deltaError) {
|
if (targetDelta - bPositive <= deltaError) {
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
sInf = sMid;
|
sInf = sMid;
|
||||||
|
@ -226,8 +294,10 @@ public class SecureProtocol {
|
||||||
return sMid;
|
return sMid;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static double calculateSNegative(double eps, double targetDelta, double sInf, double sSup) {
|
private static double calculateSNegative(double eps, double targetDelta, double initSInf, double initSSup) {
|
||||||
double deltaSup = calculateBNegative(eps, sSup);
|
double deltaSup = calculateBNegative(eps, initSSup);
|
||||||
|
double sInf = initSInf;
|
||||||
|
double sSup = initSSup;
|
||||||
while (deltaSup > targetDelta) {
|
while (deltaSup > targetDelta) {
|
||||||
sInf = sSup;
|
sInf = sSup;
|
||||||
sSup = 2 * sInf;
|
sSup = 2 * sInf;
|
||||||
|
@ -238,9 +308,9 @@ public class SecureProtocol {
|
||||||
int iterMax = 1000;
|
int iterMax = 1000;
|
||||||
int iters = 0;
|
int iters = 0;
|
||||||
while (true) {
|
while (true) {
|
||||||
double b = calculateBNegative(eps, sMid);
|
double bNegative = calculateBNegative(eps, sMid);
|
||||||
if (b <= targetDelta) {
|
if (bNegative <= targetDelta) {
|
||||||
if (targetDelta - b <= deltaError) {
|
if (targetDelta - bNegative <= deltaError) {
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
sSup = sMid;
|
sSup = sMid;
|
||||||
|
@ -259,17 +329,26 @@ public class SecureProtocol {
|
||||||
|
|
||||||
private static double calculateSigma(double clipNorm, double eps, double targetDelta) {
|
private static double calculateSigma(double clipNorm, double eps, double targetDelta) {
|
||||||
double deltaZero = calculateBPositive(eps, 0);
|
double deltaZero = calculateBPositive(eps, 0);
|
||||||
double alpha = 1;
|
double alpha = 1d;
|
||||||
if (targetDelta > deltaZero) {
|
if (targetDelta > deltaZero) {
|
||||||
double s = calculateSPositive(eps, targetDelta, 0, 1);
|
double sPositive = calculateSPositive(eps, targetDelta, 0, 1);
|
||||||
alpha = Math.sqrt(1.0 + s / 2.0) - Math.sqrt(s / 2.0);
|
alpha = Math.sqrt(1.0 + sPositive / 2.0) - Math.sqrt(sPositive / 2.0);
|
||||||
} else if (targetDelta < deltaZero) {
|
} else if (targetDelta < deltaZero) {
|
||||||
double s = calculateSNegative(eps, targetDelta, 0, 1);
|
double sNegative = calculateSNegative(eps, targetDelta, 0, 1);
|
||||||
alpha = Math.sqrt(1.0 + s / 2.0) + Math.sqrt(s / 2.0);
|
alpha = Math.sqrt(1.0 + sNegative / 2.0) + Math.sqrt(sNegative / 2.0);
|
||||||
|
} else {
|
||||||
|
LOGGER.info(Common.addTag("[Encrypt] targetDelta = deltaZero"));
|
||||||
}
|
}
|
||||||
return alpha * clipNorm / Math.sqrt(2.0 * eps);
|
return alpha * clipNorm / Math.sqrt(2.0 * eps);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add differential privacy mask to model weights.
|
||||||
|
*
|
||||||
|
* @param builder the FlatBufferBuilder object used for serialization model weights.
|
||||||
|
* @param trainDataSize tne size of train data set.
|
||||||
|
* @return the serialized model weights after adding masks.
|
||||||
|
*/
|
||||||
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
public int[] dpMaskModel(FlatBufferBuilder builder, int trainDataSize) {
|
||||||
// get feature map
|
// get feature map
|
||||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||||
|
@ -279,6 +358,9 @@ public class SecureProtocol {
|
||||||
} else if (flParameter.getFlName().equals(LENET)) {
|
} else if (flParameter.getFlName().equals(LENET)) {
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||||
|
} else {
|
||||||
|
LOGGER.severe(Common.addTag("[Encrypt] the flName is not valid, only support: lenet, albert"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
Map<String, float[]> mapBeforeTrain = modelMap;
|
Map<String, float[]> mapBeforeTrain = modelMap;
|
||||||
int featureSize = encryptFeatureName.size();
|
int featureSize = encryptFeatureName.size();
|
||||||
|
@ -286,19 +368,18 @@ public class SecureProtocol {
|
||||||
double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta);
|
double gaussianSigma = calculateSigma(dpNormClip, dpEps, dpDelta);
|
||||||
LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "============="));
|
LOGGER.info(Common.addTag("[Encrypt] =============Noise sigma of DP is: " + gaussianSigma + "============="));
|
||||||
|
|
||||||
// prepare gaussian noise
|
|
||||||
SecureRandom random = new SecureRandom();
|
|
||||||
int randomInt = random.nextInt();
|
|
||||||
Random r = new Random(randomInt);
|
|
||||||
|
|
||||||
// calculate l2-norm of all layers' update array
|
// calculate l2-norm of all layers' update array
|
||||||
double updateL2Norm = 0;
|
double updateL2Norm = 0d;
|
||||||
for (int i = 0; i < featureSize; i++) {
|
for (int i = 0; i < featureSize; i++) {
|
||||||
String key = encryptFeatureName.get(i);
|
String key = encryptFeatureName.get(i);
|
||||||
float[] data = map.get(key);
|
float[] data = map.get(key);
|
||||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||||
for (int j = 0; j < data.length; j++) {
|
for (int j = 0; j < data.length; j++) {
|
||||||
float rawData = data[j];
|
float rawData = data[j];
|
||||||
|
if (j >= dataBeforeTrain.length) {
|
||||||
|
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
|
||||||
|
return new int[0];
|
||||||
|
}
|
||||||
float rawDataBeforeTrain = dataBeforeTrain[j];
|
float rawDataBeforeTrain = dataBeforeTrain[j];
|
||||||
float updateData = rawData - rawDataBeforeTrain;
|
float updateData = rawData - rawDataBeforeTrain;
|
||||||
updateL2Norm += updateData * updateData;
|
updateL2Norm += updateData * updateData;
|
||||||
|
@ -311,11 +392,26 @@ public class SecureProtocol {
|
||||||
int[] featuresMap = new int[featureSize];
|
int[] featuresMap = new int[featureSize];
|
||||||
for (int i = 0; i < featureSize; i++) {
|
for (int i = 0; i < featureSize; i++) {
|
||||||
String key = encryptFeatureName.get(i);
|
String key = encryptFeatureName.get(i);
|
||||||
|
if (!map.containsKey(key)) {
|
||||||
|
LOGGER.severe("[Encrypt] the key: " + key + " is not in map, please check!");
|
||||||
|
return new int[0];
|
||||||
|
}
|
||||||
float[] data = map.get(key);
|
float[] data = map.get(key);
|
||||||
float[] data2 = new float[data.length];
|
float[] data2 = new float[data.length];
|
||||||
|
if (!mapBeforeTrain.containsKey(key)) {
|
||||||
|
LOGGER.severe("[Encrypt] the key: " + key + " is not in mapBeforeTrain, please check!");
|
||||||
|
return new int[0];
|
||||||
|
}
|
||||||
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
float[] dataBeforeTrain = mapBeforeTrain.get(key);
|
||||||
|
|
||||||
|
// prepare gaussian noise
|
||||||
|
SecureRandom secureRandom = Common.getSecureRandom();
|
||||||
for (int j = 0; j < data.length; j++) {
|
for (int j = 0; j < data.length; j++) {
|
||||||
float rawData = data[j];
|
float rawData = data[j];
|
||||||
|
if (j >= dataBeforeTrain.length) {
|
||||||
|
LOGGER.severe("[Encrypt] the index j is out of range for array dataBeforeTrain, please check");
|
||||||
|
return new int[0];
|
||||||
|
}
|
||||||
float rawDataBeforeTrain = dataBeforeTrain[j];
|
float rawDataBeforeTrain = dataBeforeTrain[j];
|
||||||
float updateData = rawData - rawDataBeforeTrain;
|
float updateData = rawData - rawDataBeforeTrain;
|
||||||
|
|
||||||
|
@ -323,7 +419,7 @@ public class SecureProtocol {
|
||||||
updateData *= clipFactor;
|
updateData *= clipFactor;
|
||||||
|
|
||||||
// add noise
|
// add noise
|
||||||
double gaussianNoise = r.nextGaussian() * gaussianSigma;
|
double gaussianNoise = secureRandom.nextGaussian() * gaussianSigma;
|
||||||
updateData += gaussianNoise;
|
updateData += gaussianNoise;
|
||||||
data2[j] = rawDataBeforeTrain + updateData;
|
data2[j] = rawDataBeforeTrain + updateData;
|
||||||
data2[j] = data2[j] * trainDataSize;
|
data2[j] = data2[j] * trainDataSize;
|
||||||
|
@ -335,5 +431,4 @@ public class SecureProtocol {
|
||||||
}
|
}
|
||||||
return featuresMap;
|
return featuresMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,8 +13,14 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The training mode of federated learning.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public enum ServerMod {
|
public enum ServerMod {
|
||||||
FEDERATED_LEARNING,
|
FEDERATED_LEARNING,
|
||||||
HYBRID_TRAINING
|
HYBRID_TRAINING
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
* You may obtain a copy of the License at
|
* You may obtain a copy of the License at
|
||||||
|
@ -13,13 +12,20 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
|
|
||||||
import com.mindspore.flclient.model.AlInferBert;
|
import com.mindspore.flclient.model.AlInferBert;
|
||||||
import com.mindspore.flclient.model.AlTrainBert;
|
import com.mindspore.flclient.model.AlTrainBert;
|
||||||
import com.mindspore.flclient.model.SessionUtil;
|
import com.mindspore.flclient.model.SessionUtil;
|
||||||
import com.mindspore.flclient.model.TrainLenet;
|
import com.mindspore.flclient.model.TrainLenet;
|
||||||
|
|
||||||
|
import mindspore.schema.FLPlan;
|
||||||
import mindspore.schema.FeatureMap;
|
import mindspore.schema.FeatureMap;
|
||||||
import mindspore.schema.RequestFLJob;
|
import mindspore.schema.RequestFLJob;
|
||||||
import mindspore.schema.ResponseCode;
|
import mindspore.schema.ResponseCode;
|
||||||
|
@ -28,15 +34,28 @@ import mindspore.schema.ResponseFLJob;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
/**
|
||||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
* StartFLJob
|
||||||
|
*
|
||||||
|
* @since 2021-08-25
|
||||||
|
*/
|
||||||
public class StartFLJob {
|
public class StartFLJob {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString());
|
||||||
|
private static volatile StartFLJob startFLJob;
|
||||||
|
|
||||||
static {
|
static {
|
||||||
System.loadLibrary("mindspore-lite-jni");
|
System.loadLibrary("mindspore-lite-jni");
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final Logger LOGGER = Logger.getLogger(StartFLJob.class.toString());
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
|
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||||
|
|
||||||
|
private int featureSize;
|
||||||
|
private String nextRequestTime;
|
||||||
|
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
||||||
|
|
||||||
|
private StartFLJob() {
|
||||||
|
}
|
||||||
|
|
||||||
class RequestStartFLJobBuilder {
|
class RequestStartFLJobBuilder {
|
||||||
private RequestFLJob requestFLJob;
|
private RequestFLJob requestFLJob;
|
||||||
|
@ -51,21 +70,53 @@ public class StartFLJob {
|
||||||
builder = new FlatBufferBuilder();
|
builder = new FlatBufferBuilder();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set flName
|
||||||
|
*
|
||||||
|
* @param name String
|
||||||
|
* @return RequestStartFLJobBuilder
|
||||||
|
*/
|
||||||
public RequestStartFLJobBuilder flName(String name) {
|
public RequestStartFLJobBuilder flName(String name) {
|
||||||
|
if (name == null || name.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <name> is null or empty, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
this.nameOffset = this.builder.createString(name);
|
this.nameOffset = this.builder.createString(name);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set id
|
||||||
|
*
|
||||||
|
* @param id String
|
||||||
|
* @return RequestStartFLJobBuilder
|
||||||
|
*/
|
||||||
public RequestStartFLJobBuilder id(String id) {
|
public RequestStartFLJobBuilder id(String id) {
|
||||||
|
if (id == null || id.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the parameter of <id> is null or empty, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
this.idOffset = this.builder.createString(id);
|
this.idOffset = this.builder.createString(id);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set time
|
||||||
|
*
|
||||||
|
* @param timestamp long
|
||||||
|
* @return RequestStartFLJobBuilder
|
||||||
|
*/
|
||||||
public RequestStartFLJobBuilder time(long timestamp) {
|
public RequestStartFLJobBuilder time(long timestamp) {
|
||||||
this.timestampOffset = builder.createString(String.valueOf(timestamp));
|
this.timestampOffset = builder.createString(String.valueOf(timestamp));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set dataSize
|
||||||
|
*
|
||||||
|
* @param dataSize int
|
||||||
|
* @return RequestStartFLJobBuilder
|
||||||
|
*/
|
||||||
public RequestStartFLJobBuilder dataSize(int dataSize) {
|
public RequestStartFLJobBuilder dataSize(int dataSize) {
|
||||||
// temp code need confirm
|
// temp code need confirm
|
||||||
this.dataSize = dataSize;
|
this.dataSize = dataSize;
|
||||||
|
@ -73,11 +124,22 @@ public class StartFLJob {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set iteration
|
||||||
|
*
|
||||||
|
* @param iteration iteration
|
||||||
|
* @return RequestStartFLJobBuilder
|
||||||
|
*/
|
||||||
public RequestStartFLJobBuilder iteration(int iteration) {
|
public RequestStartFLJobBuilder iteration(int iteration) {
|
||||||
this.iteration = iteration;
|
this.iteration = iteration;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* build protobuffer
|
||||||
|
*
|
||||||
|
* @return byte[] data
|
||||||
|
*/
|
||||||
public byte[] build() {
|
public byte[] build() {
|
||||||
int root = RequestFLJob.createRequestFLJob(this.builder, this.nameOffset, this.idOffset, this.iteration,
|
int root = RequestFLJob.createRequestFLJob(this.builder, this.nameOffset, this.idOffset, this.iteration,
|
||||||
this.dataSize, this.timestampOffset);
|
this.dataSize, this.timestampOffset);
|
||||||
|
@ -86,20 +148,11 @@ public class StartFLJob {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static volatile StartFLJob startFLJob;
|
/**
|
||||||
|
* getInstance of StartFLJob
|
||||||
private FLClientStatus status;
|
*
|
||||||
|
* @return StartFLJob instance
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
*/
|
||||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
|
||||||
private int featureSize;
|
|
||||||
private String nextRequestTime;
|
|
||||||
private ArrayList<String> encryptFeatureName = new ArrayList<String>();
|
|
||||||
|
|
||||||
private StartFLJob() {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public static StartFLJob getInstance() {
|
public static StartFLJob getInstance() {
|
||||||
StartFLJob localRef = startFLJob;
|
StartFLJob localRef = startFLJob;
|
||||||
if (localRef == null) {
|
if (localRef == null) {
|
||||||
|
@ -117,6 +170,14 @@ public class StartFLJob {
|
||||||
return nextRequestTime;
|
return nextRequestTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get request start FLJob
|
||||||
|
*
|
||||||
|
* @param dataSize dataSize
|
||||||
|
* @param iteration iteration
|
||||||
|
* @param time time
|
||||||
|
* @return byte[] data
|
||||||
|
*/
|
||||||
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) {
|
public byte[] getRequestStartFLJob(int dataSize, int iteration, long time) {
|
||||||
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
|
RequestStartFLJobBuilder builder = new RequestStartFLJobBuilder();
|
||||||
return builder.flName(flParameter.getFlName())
|
return builder.flName(flParameter.getFlName())
|
||||||
|
@ -135,6 +196,7 @@ public class StartFLJob {
|
||||||
return encryptFeatureName;
|
return encryptFeatureName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) {
|
private FLClientStatus parseResponseAlbert(ResponseFLJob flJob) {
|
||||||
int fmCount = flJob.featureMapLength();
|
int fmCount = flJob.featureMapLength();
|
||||||
encryptFeatureName.clear();
|
encryptFeatureName.clear();
|
||||||
|
@ -149,6 +211,10 @@ public class StartFLJob {
|
||||||
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
ArrayList<FeatureMap> inferFeatureMaps = new ArrayList<FeatureMap>();
|
||||||
for (int i = 0; i < fmCount; i++) {
|
for (int i = 0; i < fmCount; i++) {
|
||||||
FeatureMap feature = flJob.featureMap(i);
|
FeatureMap feature = flJob.featureMap(i);
|
||||||
|
if (feature == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
String featureName = feature.weightFullname();
|
String featureName = feature.weightFullname();
|
||||||
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
if (localFLParameter.getAlbertWeightName().contains(featureName)) {
|
||||||
albertFeatureMaps.add(feature);
|
albertFeatureMaps.add(feature);
|
||||||
|
@ -160,19 +226,23 @@ public class StartFLJob {
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
|
||||||
|
"weightLength: " + feature.dataLength()));
|
||||||
}
|
}
|
||||||
int tag = 0;
|
int tag = 0;
|
||||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference model-----------------"));
|
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into inference " +
|
||||||
|
"model-----------------"));
|
||||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||||
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(), inferFeatureMaps);
|
tag = SessionUtil.updateFeatures(alInferBert.getTrainSession(), flParameter.getInferModelPath(),
|
||||||
|
inferFeatureMaps);
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
|
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into train model-----------------"));
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), albertFeatureMaps);
|
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||||
|
albertFeatureMaps);
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
|
@ -182,16 +252,22 @@ public class StartFLJob {
|
||||||
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
ArrayList<FeatureMap> featureMaps = new ArrayList<FeatureMap>();
|
||||||
for (int i = 0; i < fmCount; i++) {
|
for (int i = 0; i < fmCount; i++) {
|
||||||
FeatureMap feature = flJob.featureMap(i);
|
FeatureMap feature = flJob.featureMap(i);
|
||||||
|
if (feature == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
String featureName = feature.weightFullname();
|
String featureName = feature.weightFullname();
|
||||||
featureMaps.add(feature);
|
featureMaps.add(feature);
|
||||||
featureSize += feature.dataLength();
|
featureSize += feature.dataLength();
|
||||||
encryptFeatureName.add(featureName);
|
encryptFeatureName.add(featureName);
|
||||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", " +
|
||||||
|
"weightLength: " + feature.dataLength()));
|
||||||
}
|
}
|
||||||
int tag = 0;
|
int tag = 0;
|
||||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(), featureMaps);
|
tag = SessionUtil.updateFeatures(alTrainBert.getTrainSession(), flParameter.getTrainModelPath(),
|
||||||
|
featureMaps);
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
LOGGER.severe(Common.addTag("[startFLJob] unsolved error code in <SessionUtil.updateFeatures>"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
|
@ -206,11 +282,16 @@ public class StartFLJob {
|
||||||
encryptFeatureName.clear();
|
encryptFeatureName.clear();
|
||||||
for (int i = 0; i < fmCount; i++) {
|
for (int i = 0; i < fmCount; i++) {
|
||||||
FeatureMap feature = flJob.featureMap(i);
|
FeatureMap feature = flJob.featureMap(i);
|
||||||
|
if (feature == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the feature returned from server is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
String featureName = feature.weightFullname();
|
String featureName = feature.weightFullname();
|
||||||
featureMaps.add(feature);
|
featureMaps.add(feature);
|
||||||
featureSize += feature.dataLength();
|
featureSize += feature.dataLength();
|
||||||
encryptFeatureName.add(featureName);
|
encryptFeatureName.add(featureName);
|
||||||
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " + feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
LOGGER.info(Common.addTag("[startFLJob] weightFullname: " +
|
||||||
|
feature.weightFullname() + ", weightLength: " + feature.dataLength()));
|
||||||
}
|
}
|
||||||
int tag = 0;
|
int tag = 0;
|
||||||
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
LOGGER.info(Common.addTag("[startFLJob] ----------------loading weight into model-----------------"));
|
||||||
|
@ -223,7 +304,22 @@ public class StartFLJob {
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* response res
|
||||||
|
*
|
||||||
|
* @param flJob ResponseFLJob
|
||||||
|
* @return FLClientStatus
|
||||||
|
*/
|
||||||
public FLClientStatus doResponse(ResponseFLJob flJob) {
|
public FLClientStatus doResponse(ResponseFLJob flJob) {
|
||||||
|
if (flJob == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the input parameter flJob is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
FLPlan flPlanConfig = flJob.flPlanConfig();
|
||||||
|
if (flPlanConfig == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[startFLJob] the flPlanConfig is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode()));
|
LOGGER.info(Common.addTag("[startFLJob] return retCode: " + flJob.retcode()));
|
||||||
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
|
LOGGER.info(Common.addTag("[startFLJob] reason: " + flJob.reason()));
|
||||||
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
|
LOGGER.info(Common.addTag("[startFLJob] iteration: " + flJob.iteration()));
|
||||||
|
@ -236,11 +332,12 @@ public class StartFLJob {
|
||||||
|
|
||||||
switch (retCode) {
|
switch (retCode) {
|
||||||
case (ResponseCode.SUCCEED):
|
case (ResponseCode.SUCCEED):
|
||||||
localFLParameter.setServerMod(flJob.flPlanConfig().serverMode());
|
localFLParameter.setServerMod(flPlanConfig.serverMode());
|
||||||
if (ALBERT.equals(flParameter.getFlName())) {
|
if (ALBERT.equals(flParameter.getFlName())) {
|
||||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
|
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseAlbert>"));
|
||||||
status = parseResponseAlbert(flJob);
|
status = parseResponseAlbert(flJob);
|
||||||
} else if (LENET.equals(flParameter.getFlName())) {
|
}
|
||||||
|
if (LENET.equals(flParameter.getFlName())) {
|
||||||
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
|
LOGGER.info(Common.addTag("[startFLJob] into <parseResponseLenet>"));
|
||||||
status = parseResponseLenet(flJob);
|
status = parseResponseLenet(flJob);
|
||||||
}
|
}
|
||||||
|
@ -256,8 +353,4 @@ public class StartFLJob {
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus getStatus() {
|
|
||||||
return this.status;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -13,35 +13,49 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
import com.mindspore.flclient.model.AlInferBert;
|
|
||||||
import com.mindspore.flclient.model.AlTrainBert;
|
|
||||||
import com.mindspore.flclient.model.SessionUtil;
|
|
||||||
import com.mindspore.flclient.model.TrainLenet;
|
|
||||||
import mindspore.schema.ResponseGetModel;
|
|
||||||
|
|
||||||
import java.nio.ByteBuffer;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.logging.Logger;
|
|
||||||
|
|
||||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.model.AlInferBert;
|
||||||
|
import com.mindspore.flclient.model.AlTrainBert;
|
||||||
|
import com.mindspore.flclient.model.SessionUtil;
|
||||||
|
import com.mindspore.flclient.model.TrainLenet;
|
||||||
|
|
||||||
|
import mindspore.schema.ResponseGetModel;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.security.SecureRandom;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* SyncFLJob defines the APIs for federated learning task.
|
||||||
|
* API flJobRun() for starting federated learning on the device, the API modelInference() for inference on the
|
||||||
|
* device, and the API getModel() for obtaining the latest model on the cloud.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class SyncFLJob {
|
public class SyncFLJob {
|
||||||
private static final Logger LOGGER = Logger.getLogger(SyncFLJob.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(SyncFLJob.class.toString());
|
||||||
|
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||||
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
|
private FLJobResultCallback flJobResultCallback = new FLJobResultCallback();
|
||||||
private Map<String, float[]> oldFeatureMap;
|
private Map<String, float[]> oldFeatureMap;
|
||||||
|
|
||||||
public SyncFLJob() {
|
/**
|
||||||
}
|
* Starts a federated learning task on the device.
|
||||||
|
*
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus flJobRun() {
|
public FLClientStatus flJobRun() {
|
||||||
|
Common.setSecureRandom(new SecureRandom());
|
||||||
localFLParameter.setFlID(flParameter.getClientID());
|
localFLParameter.setFlID(flParameter.getClientID());
|
||||||
FLLiteClient client = new FLLiteClient();
|
FLLiteClient client = new FLLiteClient();
|
||||||
FLClientStatus curStatus;
|
FLClientStatus curStatus;
|
||||||
|
@ -58,17 +72,14 @@ public class SyncFLJob {
|
||||||
if (trainDataSize <= 0) {
|
if (trainDataSize <= 0) {
|
||||||
LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0"));
|
LOGGER.severe(Common.addTag("unsolved error code in <client.setInput>: the return trainDataSize<=0"));
|
||||||
curStatus = FLClientStatus.FAILED;
|
curStatus = FLClientStatus.FAILED;
|
||||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
|
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
|
||||||
|
client.getRetCode());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
client.setTrainDataSize(trainDataSize);
|
client.setTrainDataSize(trainDataSize);
|
||||||
|
|
||||||
// startFLJob
|
// startFLJob
|
||||||
curStatus = client.startFLJob();
|
curStatus = startFLJob(client);
|
||||||
while (curStatus == FLClientStatus.WAIT) {
|
|
||||||
waitSomeTime();
|
|
||||||
curStatus = client.startFLJob();
|
|
||||||
}
|
|
||||||
if (curStatus == FLClientStatus.RESTART) {
|
if (curStatus == FLClientStatus.RESTART) {
|
||||||
restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
restart("[startFLJob]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||||
continue;
|
continue;
|
||||||
|
@ -100,11 +111,7 @@ public class SyncFLJob {
|
||||||
LOGGER.info(Common.addTag("[train] train succeed"));
|
LOGGER.info(Common.addTag("[train] train succeed"));
|
||||||
|
|
||||||
// updateModel
|
// updateModel
|
||||||
curStatus = client.updateModel();
|
curStatus = updateModel(client);
|
||||||
while (curStatus == FLClientStatus.WAIT) {
|
|
||||||
waitSomeTime();
|
|
||||||
curStatus = client.updateModel();
|
|
||||||
}
|
|
||||||
if (curStatus == FLClientStatus.RESTART) {
|
if (curStatus == FLClientStatus.RESTART) {
|
||||||
restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
restart("[updateModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||||
continue;
|
continue;
|
||||||
|
@ -124,11 +131,7 @@ public class SyncFLJob {
|
||||||
}
|
}
|
||||||
|
|
||||||
// getModel
|
// getModel
|
||||||
curStatus = client.getModel();
|
curStatus = getModel(client);
|
||||||
while (curStatus == FLClientStatus.WAIT) {
|
|
||||||
waitSomeTime();
|
|
||||||
curStatus = client.getModel();
|
|
||||||
}
|
|
||||||
if (curStatus == FLClientStatus.RESTART) {
|
if (curStatus == FLClientStatus.RESTART) {
|
||||||
restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
restart("[getModel]", client.getNextRequestTime(), client.getIteration(), client.getRetCode());
|
||||||
continue;
|
continue;
|
||||||
|
@ -142,7 +145,8 @@ public class SyncFLJob {
|
||||||
|
|
||||||
// evaluate model after getting model from server
|
// evaluate model after getting model from server
|
||||||
if (flParameter.getTestDataset().equals("null")) {
|
if (flParameter.getTestDataset().equals("null")) {
|
||||||
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the combine model"));
|
LOGGER.info(Common.addTag("[evaluate] the testDataset is null, don't evaluate the model after getting" +
|
||||||
|
" model from server"));
|
||||||
} else {
|
} else {
|
||||||
curStatus = client.evaluateModel();
|
curStatus = client.evaluateModel();
|
||||||
if (curStatus == FLClientStatus.FAILED) {
|
if (curStatus == FLClientStatus.FAILED) {
|
||||||
|
@ -151,33 +155,68 @@ public class SyncFLJob {
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[evaluate] evaluate succeed"));
|
LOGGER.info(Common.addTag("[evaluate] evaluate succeed"));
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("========================================================the total response of " + client.getIteration() + ": " + curStatus + "======================================================================"));
|
LOGGER.info(Common.addTag("========================================================the total response of "
|
||||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(), client.getRetCode());
|
+ client.getIteration() + ": " + curStatus +
|
||||||
|
"======================================================================"));
|
||||||
|
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), client.getIteration(),
|
||||||
|
client.getRetCode());
|
||||||
} while (client.getIteration() < client.getIterations());
|
} while (client.getIteration() < client.getIterations());
|
||||||
client.finalize();
|
client.freeSession();
|
||||||
LOGGER.info(Common.addTag("flJobRun finish"));
|
LOGGER.info(Common.addTag("flJobRun finish"));
|
||||||
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
|
flJobResultCallback.onFlJobFinished(flParameter.getFlName(), client.getIterations(), client.getRetCode());
|
||||||
return curStatus;
|
return curStatus;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private FLClientStatus startFLJob(FLLiteClient client) {
|
||||||
|
FLClientStatus curStatus = client.startFLJob();
|
||||||
|
while (curStatus == FLClientStatus.WAIT) {
|
||||||
|
waitSomeTime();
|
||||||
|
curStatus = client.startFLJob();
|
||||||
|
}
|
||||||
|
return curStatus;
|
||||||
|
}
|
||||||
|
|
||||||
|
private FLClientStatus updateModel(FLLiteClient client) {
|
||||||
|
FLClientStatus curStatus = client.updateModel();
|
||||||
|
while (curStatus == FLClientStatus.WAIT) {
|
||||||
|
waitSomeTime();
|
||||||
|
curStatus = client.updateModel();
|
||||||
|
}
|
||||||
|
return curStatus;
|
||||||
|
}
|
||||||
|
|
||||||
|
private FLClientStatus getModel(FLLiteClient client) {
|
||||||
|
FLClientStatus curStatus = client.getModel();
|
||||||
|
while (curStatus == FLClientStatus.WAIT) {
|
||||||
|
waitSomeTime();
|
||||||
|
curStatus = client.getModel();
|
||||||
|
}
|
||||||
|
return curStatus;
|
||||||
|
}
|
||||||
|
|
||||||
private void updateDpNormClip(FLLiteClient client) {
|
private void updateDpNormClip(FLLiteClient client) {
|
||||||
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
|
EncryptLevel encryptLevel = localFLParameter.getEncryptLevel();
|
||||||
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
|
if (encryptLevel == EncryptLevel.DP_ENCRYPT) {
|
||||||
int currentIter = client.getIteration();
|
int currentIter = client.getIteration();
|
||||||
Map<String, float[]> fedFeatureMap = getFeatureMap();
|
Map<String, float[]> fedFeatureMap = getFeatureMap();
|
||||||
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
|
float fedWeightUpdateNorm = calWeightUpdateNorm(oldFeatureMap, fedFeatureMap);
|
||||||
|
if (fedWeightUpdateNorm == -1) {
|
||||||
|
LOGGER.severe(Common.addTag("[updateDpNormClip] the returned value fedWeightUpdateNorm is not valid: " +
|
||||||
|
"-1, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
|
LOGGER.info(Common.addTag("[DP] L2-norm of weights' average update is: " + fedWeightUpdateNorm));
|
||||||
float newNormCLip = (float) client.dpNormClipFactor * fedWeightUpdateNorm;
|
float newNormCLip = (float) client.getDpNormClipFactor() * fedWeightUpdateNorm;
|
||||||
if (currentIter == 1) {
|
if (currentIter == 1) {
|
||||||
client.dpNormClipAdapt = newNormCLip;
|
client.setDpNormClipAdapt(newNormCLip);
|
||||||
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
||||||
} else {
|
} else {
|
||||||
if (newNormCLip < client.dpNormClipAdapt) {
|
if (newNormCLip < client.getDpNormClipAdapt()) {
|
||||||
client.dpNormClipAdapt = newNormCLip;
|
client.setDpNormClipAdapt(newNormCLip);
|
||||||
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
LOGGER.info(Common.addTag("[DP] dpNormClip has been updated."));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.dpNormClipAdapt));
|
LOGGER.info(Common.addTag("[DP] Adaptive dpNormClip is: " + client.getDpNormClipAdapt()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,11 +229,16 @@ public class SyncFLJob {
|
||||||
}
|
}
|
||||||
|
|
||||||
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
|
private float calWeightUpdateNorm(Map<String, float[]> originalData, Map<String, float[]> newData) {
|
||||||
float updateL2Norm = 0;
|
float updateL2Norm = 0f;
|
||||||
for (String key : originalData.keySet()) {
|
for (String key : originalData.keySet()) {
|
||||||
float[] data = originalData.get(key);
|
float[] data = originalData.get(key);
|
||||||
float[] dataAfterUpdate = newData.get(key);
|
float[] dataAfterUpdate = newData.get(key);
|
||||||
for (int j = 0; j < data.length; j++) {
|
for (int j = 0; j < data.length; j++) {
|
||||||
|
if (j >= dataAfterUpdate.length) {
|
||||||
|
LOGGER.severe("[calWeightUpdateNorm] the index j is out of range for array dataAfterUpdate, " +
|
||||||
|
"please check");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
float updateData = data[j] - dataAfterUpdate[j];
|
float updateData = data[j] - dataAfterUpdate[j];
|
||||||
updateL2Norm += updateData * updateData;
|
updateL2Norm += updateData * updateData;
|
||||||
}
|
}
|
||||||
|
@ -215,12 +259,22 @@ public class SyncFLJob {
|
||||||
return featureMap;
|
return featureMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Starts an inference task on the device.
|
||||||
|
*
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public int[] modelInference() {
|
public int[] modelInference() {
|
||||||
int[] labels = new int[0];
|
int[] labels = new int[0];
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||||
LOGGER.info(Common.addTag("===========model inference============="));
|
LOGGER.info(Common.addTag("===========model inference============="));
|
||||||
labels = alInferBert.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset(), flParameter.getVocabFile(), flParameter.getIdsFile());
|
labels = alInferBert.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset(),
|
||||||
|
flParameter.getVocabFile(), flParameter.getIdsFile());
|
||||||
|
if (labels == null || labels.length == 0) {
|
||||||
|
LOGGER.severe("[model inference] the returned label from adInferBert.inferModel() is null, please " +
|
||||||
|
"check");
|
||||||
|
}
|
||||||
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
||||||
SessionUtil.free(alInferBert.getTrainSession());
|
SessionUtil.free(alInferBert.getTrainSession());
|
||||||
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
||||||
|
@ -228,49 +282,62 @@ public class SyncFLJob {
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
LOGGER.info(Common.addTag("===========model inference============="));
|
LOGGER.info(Common.addTag("===========model inference============="));
|
||||||
labels = trainLenet.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset().split(",")[0]);
|
labels = trainLenet.inferModel(flParameter.getInferModelPath(), flParameter.getTestDataset().split(",")[0]);
|
||||||
|
if (labels == null || labels.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
|
||||||
|
}
|
||||||
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
LOGGER.info(Common.addTag("[model inference] the predicted labels: " + Arrays.toString(labels)));
|
||||||
SessionUtil.free(trainLenet.getTrainSession());
|
SessionUtil.free(trainLenet.getTrainSession());
|
||||||
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
LOGGER.info(Common.addTag("[model inference] inference finish"));
|
||||||
}
|
}
|
||||||
if (labels.length == 0) {
|
|
||||||
LOGGER.severe(Common.addTag("[model inference] the return labels is null."));
|
|
||||||
}
|
|
||||||
return labels;
|
return labels;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Obtains the latest model on the cloud.
|
||||||
|
*
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus getModel() {
|
public FLClientStatus getModel() {
|
||||||
|
Common.setSecureRandom(Common.getFastSecureRandom());
|
||||||
int tag = 0;
|
int tag = 0;
|
||||||
FLClientStatus status = FLClientStatus.SUCCESS;
|
FLClientStatus status;
|
||||||
try {
|
try {
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
|
localFLParameter.setServerMod(ServerMod.HYBRID_TRAINING.toString());
|
||||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
|
||||||
|
flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
tag = alTrainBert.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the " +
|
||||||
|
"return is -1"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " + flParameter.getInferModelPath() + " Create inference Session============="));
|
LOGGER.info(Common.addTag("[getModel] ==========Loading inference model, " +
|
||||||
|
flParameter.getInferModelPath() + " Create inference Session============="));
|
||||||
AlInferBert alInferBert = AlInferBert.getInstance();
|
AlInferBert alInferBert = AlInferBert.getInstance();
|
||||||
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
tag = alInferBert.initSessionAndInputs(flParameter.getInferModelPath(), false);
|
||||||
} else if (flParameter.getFlName().equals(LENET)) {
|
} else if (flParameter.getFlName().equals(LENET)) {
|
||||||
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
|
localFLParameter.setServerMod(ServerMod.FEDERATED_LEARNING.toString());
|
||||||
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " + flParameter.getTrainModelPath() + " Create Train Session============="));
|
LOGGER.info(Common.addTag("[getModel] ==========Loading train model, " +
|
||||||
|
flParameter.getTrainModelPath() + " Create Train Session============="));
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
tag = trainLenet.initSessionAndInputs(flParameter.getTrainModelPath(), true);
|
||||||
}
|
}
|
||||||
if (tag == -1) {
|
if (tag == -1) {
|
||||||
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return is -1"));
|
LOGGER.severe(Common.addTag("[initSession] unsolved error code in <initSessionAndInputs>: the return " +
|
||||||
|
"is -1"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
FLCommunication flCommunication = FLCommunication.getInstance();
|
FLCommunication flCommunication = FLCommunication.getInstance();
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(),
|
||||||
|
flParameter.getDomainName());
|
||||||
GetModel getModelBuf = GetModel.getInstance();
|
GetModel getModelBuf = GetModel.getInstance();
|
||||||
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
|
byte[] buffer = getModelBuf.getRequestGetModel(flParameter.getFlName(), 0);
|
||||||
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
byte[] message = flCommunication.syncRequest(url + "/getModel", buffer);
|
||||||
if (Common.isSafeMod(message, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(message)) {
|
||||||
LOGGER.info(Common.addTag("[getModel] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[getModel] the server is not ready now, need wait some time and request " +
|
||||||
|
"again"));
|
||||||
status = FLClientStatus.WAIT;
|
status = FLClientStatus.WAIT;
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
@ -279,8 +346,8 @@ public class SyncFLJob {
|
||||||
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
|
ResponseGetModel responseDataBuf = ResponseGetModel.getRootAsResponseGetModel(debugBuffer);
|
||||||
status = getModelBuf.doResponse(responseDataBuf);
|
status = getModelBuf.doResponse(responseDataBuf);
|
||||||
LOGGER.info(Common.addTag("[getModel] success!"));
|
LOGGER.info(Common.addTag("[getModel] success!"));
|
||||||
} catch (Exception e) {
|
} catch (Exception ex) {
|
||||||
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + e.getMessage()));
|
LOGGER.severe(Common.addTag("[getModel] unsolved error code: catch Exception: " + ex.getMessage()));
|
||||||
status = FLClientStatus.FAILED;
|
status = FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
|
@ -299,19 +366,16 @@ public class SyncFLJob {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void waitSomeTime() {
|
private void waitSomeTime() {
|
||||||
if (flParameter.getSleepTime() != 0)
|
if (flParameter.getSleepTime() != 0) {
|
||||||
Common.sleep(flParameter.getSleepTime());
|
Common.sleep(flParameter.getSleepTime());
|
||||||
else
|
} else {
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void waitNextReqTime(String nextReqTime) {
|
private void waitNextReqTime(String nextReqTime) {
|
||||||
if (flParameter.isTimer()) {
|
long waitTime = Common.getWaitTime(nextReqTime);
|
||||||
long waitTime = Common.getWaitTime(nextReqTime);
|
Common.sleep(waitTime);
|
||||||
Common.sleep(waitTime);
|
|
||||||
} else {
|
|
||||||
waitSomeTime();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void restart(String tag, String nextReqTime, int iteration, int retcode) {
|
private void restart(String tag, String nextReqTime, int iteration, int retcode) {
|
||||||
|
@ -322,7 +386,8 @@ public class SyncFLJob {
|
||||||
|
|
||||||
private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) {
|
private void failed(String tag, int iteration, int retcode, FLClientStatus curStatus) {
|
||||||
LOGGER.info(Common.addTag(tag + " failed"));
|
LOGGER.info(Common.addTag(tag + " failed"));
|
||||||
LOGGER.info(Common.addTag("========================================================the total response of " + iteration + ": " + curStatus + "======================================================================"));
|
LOGGER.info(Common.addTag("=========================================the total response of " +
|
||||||
|
iteration + ": " + curStatus + "========================================="));
|
||||||
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
|
flJobResultCallback.onFlJobIterationFinished(flParameter.getFlName(), iteration, retcode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,53 +399,27 @@ public class SyncFLJob {
|
||||||
String flName = args[4];
|
String flName = args[4];
|
||||||
String trainModelPath = args[5];
|
String trainModelPath = args[5];
|
||||||
String inferModelPath = args[6];
|
String inferModelPath = args[6];
|
||||||
String clientID = args[7];
|
boolean useSSL = Boolean.parseBoolean(args[7]);
|
||||||
String ip = args[8];
|
String domainName = args[8];
|
||||||
boolean useSSL = Boolean.parseBoolean(args[9]);
|
boolean useElb = Boolean.parseBoolean(args[9]);
|
||||||
int port = Integer.parseInt(args[10]);
|
int serverNum = Integer.parseInt(args[10]);
|
||||||
int timeWindow = Integer.parseInt(args[11]);
|
String certPath = args[11];
|
||||||
boolean useElb = Boolean.parseBoolean(args[12]);
|
String task = args[12];
|
||||||
int serverNum = Integer.parseInt(args[13]);
|
|
||||||
boolean useHttps = Boolean.parseBoolean(args[14]);
|
|
||||||
String certPath = args[15];
|
|
||||||
String task = args[16];
|
|
||||||
|
|
||||||
FLParameter flParameter = FLParameter.getInstance();
|
FLParameter flParameter = FLParameter.getInstance();
|
||||||
LOGGER.info(Common.addTag("[args] trainDataset: " + trainDataset));
|
|
||||||
LOGGER.info(Common.addTag("[args] vocabFile: " + vocabFile));
|
|
||||||
LOGGER.info(Common.addTag("[args] idsFile: " + idsFile));
|
|
||||||
LOGGER.info(Common.addTag("[args] testDataset: " + testDataset));
|
|
||||||
LOGGER.info(Common.addTag("[args] flName: " + flName));
|
|
||||||
LOGGER.info(Common.addTag("[args] trainModelPath: " + trainModelPath));
|
|
||||||
LOGGER.info(Common.addTag("[args] inferModelPath: " + inferModelPath));
|
|
||||||
LOGGER.info(Common.addTag("[args] clientID: " + clientID));
|
|
||||||
LOGGER.info(Common.addTag("[args] ip: " + ip));
|
|
||||||
LOGGER.info(Common.addTag("[args] useSSL: " + useSSL));
|
|
||||||
LOGGER.info(Common.addTag("[args] port: " + port));
|
|
||||||
LOGGER.info(Common.addTag("[args] timeWindow: " + timeWindow));
|
|
||||||
LOGGER.info(Common.addTag("[args] useElb: " + useElb));
|
|
||||||
LOGGER.info(Common.addTag("[args] serverNum: " + serverNum));
|
|
||||||
LOGGER.info(Common.addTag("[args] useHttps: " + useHttps));
|
|
||||||
LOGGER.info(Common.addTag("[args] certPath: " + certPath));
|
|
||||||
LOGGER.info(Common.addTag("[args] task: " + task));
|
|
||||||
|
|
||||||
flParameter.setClientID(clientID);
|
|
||||||
SyncFLJob syncFLJob = new SyncFLJob();
|
SyncFLJob syncFLJob = new SyncFLJob();
|
||||||
if (task.equals("train")) {
|
if (task.equals("train")) {
|
||||||
flParameter.setUseHttps(useHttps);
|
|
||||||
if (useSSL) {
|
if (useSSL) {
|
||||||
flParameter.setCertPath(certPath);
|
flParameter.setCertPath(certPath);
|
||||||
}
|
}
|
||||||
flParameter.setHostName(ip);
|
|
||||||
flParameter.setTrainDataset(trainDataset);
|
flParameter.setTrainDataset(trainDataset);
|
||||||
flParameter.setFlName(flName);
|
flParameter.setFlName(flName);
|
||||||
flParameter.setTrainModelPath(trainModelPath);
|
flParameter.setTrainModelPath(trainModelPath);
|
||||||
flParameter.setTestDataset(testDataset);
|
flParameter.setTestDataset(testDataset);
|
||||||
flParameter.setInferModelPath(inferModelPath);
|
flParameter.setInferModelPath(inferModelPath);
|
||||||
flParameter.setIp(ip);
|
|
||||||
flParameter.setUseSSL(useSSL);
|
flParameter.setUseSSL(useSSL);
|
||||||
flParameter.setPort(port);
|
flParameter.setDomainName(domainName);
|
||||||
flParameter.setTimeWindow(timeWindow);
|
|
||||||
flParameter.setUseElb(useElb);
|
flParameter.setUseElb(useElb);
|
||||||
flParameter.setServerNum(serverNum);
|
flParameter.setServerNum(serverNum);
|
||||||
if (ALBERT.equals(flName)) {
|
if (ALBERT.equals(flName)) {
|
||||||
|
@ -398,17 +437,14 @@ public class SyncFLJob {
|
||||||
}
|
}
|
||||||
syncFLJob.modelInference();
|
syncFLJob.modelInference();
|
||||||
} else if (task.equals("getModel")) {
|
} else if (task.equals("getModel")) {
|
||||||
flParameter.setUseHttps(useHttps);
|
|
||||||
if (useSSL) {
|
if (useSSL) {
|
||||||
flParameter.setCertPath(certPath);
|
flParameter.setCertPath(certPath);
|
||||||
}
|
}
|
||||||
flParameter.setHostName(ip);
|
|
||||||
flParameter.setFlName(flName);
|
flParameter.setFlName(flName);
|
||||||
flParameter.setTrainModelPath(trainModelPath);
|
flParameter.setTrainModelPath(trainModelPath);
|
||||||
flParameter.setInferModelPath(inferModelPath);
|
flParameter.setInferModelPath(inferModelPath);
|
||||||
flParameter.setIp(ip);
|
|
||||||
flParameter.setUseSSL(useSSL);
|
flParameter.setUseSSL(useSSL);
|
||||||
flParameter.setPort(port);
|
flParameter.setDomainName(domainName);
|
||||||
flParameter.setUseElb(useElb);
|
flParameter.setUseElb(useElb);
|
||||||
flParameter.setServerNum(serverNum);
|
flParameter.setServerNum(serverNum);
|
||||||
syncFLJob.getModel();
|
syncFLJob.getModel();
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,10 +16,15 @@
|
||||||
|
|
||||||
package com.mindspore.flclient;
|
package com.mindspore.flclient;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
|
|
||||||
import com.mindspore.flclient.model.AlTrainBert;
|
import com.mindspore.flclient.model.AlTrainBert;
|
||||||
import com.mindspore.flclient.model.SessionUtil;
|
import com.mindspore.flclient.model.SessionUtil;
|
||||||
import com.mindspore.flclient.model.TrainLenet;
|
import com.mindspore.flclient.model.TrainLenet;
|
||||||
|
|
||||||
import mindspore.schema.FeatureMap;
|
import mindspore.schema.FeatureMap;
|
||||||
import mindspore.schema.RequestUpdateModel;
|
import mindspore.schema.RequestUpdateModel;
|
||||||
import mindspore.schema.ResponseCode;
|
import mindspore.schema.ResponseCode;
|
||||||
|
@ -31,145 +36,31 @@ import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import static com.mindspore.flclient.LocalFLParameter.ALBERT;
|
/**
|
||||||
import static com.mindspore.flclient.LocalFLParameter.LENET;
|
* Define the serialization method, handle the response message returned from server for updateModel request.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class UpdateModel {
|
public class UpdateModel {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
|
||||||
|
private static volatile UpdateModel updateModel;
|
||||||
|
|
||||||
static {
|
static {
|
||||||
System.loadLibrary("mindspore-lite-jni");
|
System.loadLibrary("mindspore-lite-jni");
|
||||||
}
|
}
|
||||||
|
|
||||||
class RequestUpdateModelBuilder {
|
|
||||||
private RequestUpdateModel requestUM;
|
|
||||||
private FlatBufferBuilder builder;
|
|
||||||
private int fmOffset = 0;
|
|
||||||
private int nameOffset = 0;
|
|
||||||
private int idOffset = 0;
|
|
||||||
private int timestampOffset = 0;
|
|
||||||
private int iteration = 0;
|
|
||||||
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
|
|
||||||
|
|
||||||
public RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
|
|
||||||
builder = new FlatBufferBuilder();
|
|
||||||
this.encryptLevel = encryptLevel;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RequestUpdateModelBuilder flName(String name) {
|
|
||||||
this.nameOffset = this.builder.createString(name);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RequestUpdateModelBuilder time() {
|
|
||||||
Date date = new Date();
|
|
||||||
long time = date.getTime();
|
|
||||||
this.timestampOffset = builder.createString(String.valueOf(time));
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RequestUpdateModelBuilder iteration(int iteration) {
|
|
||||||
this.iteration = iteration;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RequestUpdateModelBuilder id(String id) {
|
|
||||||
this.idOffset = this.builder.createString(id);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
|
|
||||||
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
|
|
||||||
switch (encryptLevel) {
|
|
||||||
case PW_ENCRYPT:
|
|
||||||
try {
|
|
||||||
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
|
|
||||||
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
|
|
||||||
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is null, please check");
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
|
||||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
|
||||||
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
|
||||||
return this;
|
|
||||||
} catch (Exception e) {
|
|
||||||
LOGGER.severe("[Encrypt] catch error in maskModel: " + e.getMessage());
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
|
||||||
case DP_ENCRYPT:
|
|
||||||
try {
|
|
||||||
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
|
|
||||||
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
|
|
||||||
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is null, please check");
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
|
||||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
|
|
||||||
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
|
||||||
return this;
|
|
||||||
} catch (Exception e) {
|
|
||||||
LOGGER.severe(Common.addTag("[Encrypt] catch error in maskModel: " + e.getMessage()));
|
|
||||||
throw new RuntimeException();
|
|
||||||
}
|
|
||||||
case NOT_ENCRYPT:
|
|
||||||
default:
|
|
||||||
Map<String, float[]> map = new HashMap<String, float[]>();
|
|
||||||
if (flParameter.getFlName().equals(ALBERT)) {
|
|
||||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
|
|
||||||
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
|
||||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
|
||||||
if (map.isEmpty()) {
|
|
||||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
|
|
||||||
status = FLClientStatus.FAILED;
|
|
||||||
}
|
|
||||||
} else if (flParameter.getFlName().equals(LENET)) {
|
|
||||||
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " + flParameter.getFlName()));
|
|
||||||
TrainLenet trainLenet = TrainLenet.getInstance();
|
|
||||||
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
|
||||||
if (map.isEmpty()) {
|
|
||||||
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil.convertTensorToFeatures>"));
|
|
||||||
status = FLClientStatus.FAILED;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
int featureSize = encryptFeatureName.size();
|
|
||||||
int[] fmOffsets = new int[featureSize];
|
|
||||||
for (int i = 0; i < featureSize; i++) {
|
|
||||||
String key = encryptFeatureName.get(i);
|
|
||||||
float[] data = map.get(key);
|
|
||||||
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature size: " + data.length));
|
|
||||||
for (int j = 0; j < data.length; j++) {
|
|
||||||
float rawData = data[j];
|
|
||||||
data[j] = data[j] * trainDataSize;
|
|
||||||
}
|
|
||||||
int featureName = builder.createString(key);
|
|
||||||
int weight = FeatureMap.createDataVector(builder, data);
|
|
||||||
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
|
||||||
fmOffsets[i] = featureMap;
|
|
||||||
}
|
|
||||||
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public byte[] build() {
|
|
||||||
RequestUpdateModel.startRequestUpdateModel(this.builder);
|
|
||||||
RequestUpdateModel.addFlName(builder, nameOffset);
|
|
||||||
RequestUpdateModel.addFlId(this.builder, idOffset);
|
|
||||||
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
|
|
||||||
RequestUpdateModel.addIteration(builder, this.iteration);
|
|
||||||
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
|
|
||||||
int root = RequestUpdateModel.endRequestUpdateModel(builder);
|
|
||||||
builder.finish(root);
|
|
||||||
return builder.sizedByteArray();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static final Logger LOGGER = Logger.getLogger(UpdateModel.class.toString());
|
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||||
private String nextRequestTime;
|
|
||||||
private FLClientStatus status;
|
private FLClientStatus status;
|
||||||
private static volatile UpdateModel updateModel;
|
|
||||||
|
|
||||||
private UpdateModel() {
|
private UpdateModel() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the singleton object of the class UpdateModel.
|
||||||
|
*
|
||||||
|
* @return the singleton object of the class UpdateModel.
|
||||||
|
*/
|
||||||
public static UpdateModel getInstance() {
|
public static UpdateModel getInstance() {
|
||||||
UpdateModel localRef = updateModel;
|
UpdateModel localRef = updateModel;
|
||||||
if (localRef == null) {
|
if (localRef == null) {
|
||||||
|
@ -183,25 +74,35 @@ public class UpdateModel {
|
||||||
return localRef;
|
return localRef;
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getNextRequestTime() {
|
|
||||||
return nextRequestTime;
|
|
||||||
}
|
|
||||||
|
|
||||||
public FLClientStatus getStatus() {
|
public FLClientStatus getStatus() {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a flatBuffer builder of RequestUpdateModel.
|
||||||
|
*
|
||||||
|
* @param iteration current iteration of federated learning task.
|
||||||
|
* @param secureProtocol the object that defines encryption and decryption methods.
|
||||||
|
* @param trainDataSize the size of train date set.
|
||||||
|
* @return the flatBuffer builder of RequestUpdateModel in byte[] format.
|
||||||
|
*/
|
||||||
public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) {
|
public byte[] getRequestUpdateFLJob(int iteration, SecureProtocol secureProtocol, int trainDataSize) {
|
||||||
RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel());
|
RequestUpdateModelBuilder builder = new RequestUpdateModelBuilder(localFLParameter.getEncryptLevel());
|
||||||
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID()).featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
|
return builder.flName(flParameter.getFlName()).time().id(localFLParameter.getFlID())
|
||||||
|
.featuresMap(secureProtocol, trainDataSize).iteration(iteration).build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle the response message returned from server.
|
||||||
|
*
|
||||||
|
* @param response the response message returned from server.
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
public FLClientStatus doResponse(ResponseUpdateModel response) {
|
public FLClientStatus doResponse(ResponseUpdateModel response) {
|
||||||
LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================"));
|
LOGGER.info(Common.addTag("[updateModel] ==========updateModel response================"));
|
||||||
LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode()));
|
LOGGER.info(Common.addTag("[updateModel] ==========retcode: " + response.retcode()));
|
||||||
LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason()));
|
LOGGER.info(Common.addTag("[updateModel] ==========reason: " + response.reason()));
|
||||||
LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime()));
|
LOGGER.info(Common.addTag("[updateModel] ==========next request time: " + response.nextReqTime()));
|
||||||
nextRequestTime = response.nextReqTime();
|
|
||||||
switch (response.retcode()) {
|
switch (response.retcode()) {
|
||||||
case (ResponseCode.SUCCEED):
|
case (ResponseCode.SUCCEED):
|
||||||
LOGGER.info(Common.addTag("[updateModel] updateModel success"));
|
LOGGER.info(Common.addTag("[updateModel] updateModel success"));
|
||||||
|
@ -213,8 +114,165 @@ public class UpdateModel {
|
||||||
LOGGER.warning(Common.addTag("[updateModel] catch RequestError or SystemError"));
|
LOGGER.warning(Common.addTag("[updateModel] catch RequestError or SystemError"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
default:
|
default:
|
||||||
LOGGER.severe(Common.addTag("[updateModel]the return <retcode> from server is invalid: " + response.retcode()));
|
LOGGER.severe(Common.addTag("[updateModel]the return <retCode> from server is invalid: " +
|
||||||
|
response.retcode()));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class RequestUpdateModelBuilder {
|
||||||
|
private RequestUpdateModel requestUM;
|
||||||
|
private FlatBufferBuilder builder;
|
||||||
|
private int fmOffset = 0;
|
||||||
|
private int nameOffset = 0;
|
||||||
|
private int idOffset = 0;
|
||||||
|
private int timestampOffset = 0;
|
||||||
|
private int iteration = 0;
|
||||||
|
private EncryptLevel encryptLevel = EncryptLevel.NOT_ENCRYPT;
|
||||||
|
|
||||||
|
private RequestUpdateModelBuilder(EncryptLevel encryptLevel) {
|
||||||
|
builder = new FlatBufferBuilder();
|
||||||
|
this.encryptLevel = encryptLevel;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Serialize the element flName in RequestUpdateModel.
|
||||||
|
*
|
||||||
|
* @param name the model name.
|
||||||
|
* @return the RequestUpdateModelBuilder object.
|
||||||
|
*/
|
||||||
|
private RequestUpdateModelBuilder flName(String name) {
|
||||||
|
if (name == null || name.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[updateModel] the parameter of <name> is null or empty, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
this.nameOffset = this.builder.createString(name);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Serialize the element timestamp in RequestUpdateModel.
|
||||||
|
*
|
||||||
|
* @return the RequestUpdateModelBuilder object.
|
||||||
|
*/
|
||||||
|
private RequestUpdateModelBuilder time() {
|
||||||
|
Date date = new Date();
|
||||||
|
long time = date.getTime();
|
||||||
|
this.timestampOffset = builder.createString(String.valueOf(time));
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Serialize the element iteration in RequestUpdateModel.
|
||||||
|
*
|
||||||
|
* @param iteration current iteration of federated learning task.
|
||||||
|
* @return the RequestUpdateModelBuilder object.
|
||||||
|
*/
|
||||||
|
private RequestUpdateModelBuilder iteration(int iteration) {
|
||||||
|
this.iteration = iteration;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Serialize the element fl_id in RequestUpdateModel.
|
||||||
|
*
|
||||||
|
* @param id a number that uniquely identifies a client.
|
||||||
|
* @return the RequestUpdateModelBuilder object.
|
||||||
|
*/
|
||||||
|
private RequestUpdateModelBuilder id(String id) {
|
||||||
|
if (id == null || id.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[updateModel] the parameter of <id> is null or empty, please check!"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
this.idOffset = this.builder.createString(id);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private RequestUpdateModelBuilder featuresMap(SecureProtocol secureProtocol, int trainDataSize) {
|
||||||
|
ArrayList<String> encryptFeatureName = secureProtocol.getEncryptFeatureName();
|
||||||
|
switch (encryptLevel) {
|
||||||
|
case PW_ENCRYPT:
|
||||||
|
int[] fmOffsetsPW = secureProtocol.pwMaskModel(builder, trainDataSize);
|
||||||
|
if (fmOffsetsPW == null || fmOffsetsPW.length == 0) {
|
||||||
|
LOGGER.severe("[Encrypt] the return fmOffsetsPW from <secureProtocol.pwMaskModel> is " +
|
||||||
|
"null, please check");
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsPW);
|
||||||
|
LOGGER.info(Common.addTag("[Encrypt] pairwise mask model ok!"));
|
||||||
|
return this;
|
||||||
|
case DP_ENCRYPT:
|
||||||
|
int[] fmOffsetsDP = secureProtocol.dpMaskModel(builder, trainDataSize);
|
||||||
|
if (fmOffsetsDP == null || fmOffsetsDP.length == 0) {
|
||||||
|
LOGGER.severe("[Encrypt] the return fmOffsetsDP from <secureProtocol.dpMaskModel> is " +
|
||||||
|
"null, please check");
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsetsDP);
|
||||||
|
LOGGER.info(Common.addTag("[Encrypt] DP mask model ok!"));
|
||||||
|
return this;
|
||||||
|
case NOT_ENCRYPT:
|
||||||
|
default:
|
||||||
|
Map<String, float[]> map = new HashMap<String, float[]>();
|
||||||
|
if (flParameter.getFlName().equals(ALBERT)) {
|
||||||
|
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
|
||||||
|
flParameter.getFlName()));
|
||||||
|
AlTrainBert alTrainBert = AlTrainBert.getInstance();
|
||||||
|
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(alTrainBert.getTrainSession()));
|
||||||
|
if (map.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
|
||||||
|
".convertTensorToFeatures>"));
|
||||||
|
status = FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
} else if (flParameter.getFlName().equals(LENET)) {
|
||||||
|
LOGGER.info(Common.addTag("[updateModel] serialize feature map for " +
|
||||||
|
flParameter.getFlName()));
|
||||||
|
TrainLenet trainLenet = TrainLenet.getInstance();
|
||||||
|
map = SessionUtil.convertTensorToFeatures(SessionUtil.getFeatures(trainLenet.getTrainSession()));
|
||||||
|
if (map.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[updateModel] the return map is empty in <SessionUtil" +
|
||||||
|
".convertTensorToFeatures>"));
|
||||||
|
status = FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LOGGER.severe(Common.addTag("[updateModel] the flName is not valid"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
int featureSize = encryptFeatureName.size();
|
||||||
|
int[] fmOffsets = new int[featureSize];
|
||||||
|
for (int i = 0; i < featureSize; i++) {
|
||||||
|
String key = encryptFeatureName.get(i);
|
||||||
|
float[] data = map.get(key);
|
||||||
|
LOGGER.info(Common.addTag("[updateModel build featuresMap] feature name: " + key + " feature " +
|
||||||
|
"size: " + data.length));
|
||||||
|
for (int j = 0; j < data.length; j++) {
|
||||||
|
data[j] = data[j] * trainDataSize;
|
||||||
|
}
|
||||||
|
int featureName = builder.createString(key);
|
||||||
|
int weight = FeatureMap.createDataVector(builder, data);
|
||||||
|
int featureMap = FeatureMap.createFeatureMap(builder, featureName, weight);
|
||||||
|
fmOffsets[i] = featureMap;
|
||||||
|
}
|
||||||
|
this.fmOffset = RequestUpdateModel.createFeatureMapVector(builder, fmOffsets);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a flatBuffer builder of RequestUpdateModel.
|
||||||
|
*
|
||||||
|
* @return the flatBuffer builder of RequestUpdateModel in byte[] format.
|
||||||
|
*/
|
||||||
|
private byte[] build() {
|
||||||
|
RequestUpdateModel.startRequestUpdateModel(this.builder);
|
||||||
|
RequestUpdateModel.addFlName(builder, nameOffset);
|
||||||
|
RequestUpdateModel.addFlId(this.builder, idOffset);
|
||||||
|
RequestUpdateModel.addTimestamp(builder, this.timestampOffset);
|
||||||
|
RequestUpdateModel.addIteration(builder, this.iteration);
|
||||||
|
RequestUpdateModel.addFeatureMap(builder, this.fmOffset);
|
||||||
|
int root = RequestUpdateModel.endRequestUpdateModel(builder);
|
||||||
|
builder.finish(root);
|
||||||
|
return builder.sizedByteArray();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,21 +16,33 @@
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher;
|
package com.mindspore.flclient.cipher;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.I_VEC_LEN;
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
|
||||||
|
|
||||||
import com.mindspore.flclient.Common;
|
import com.mindspore.flclient.Common;
|
||||||
|
|
||||||
import javax.crypto.Cipher;
|
|
||||||
import javax.crypto.spec.IvParameterSpec;
|
|
||||||
import javax.crypto.spec.SecretKeySpec;
|
|
||||||
import java.io.UnsupportedEncodingException;
|
import java.io.UnsupportedEncodingException;
|
||||||
|
import java.security.InvalidAlgorithmParameterException;
|
||||||
|
import java.security.InvalidKeyException;
|
||||||
|
import java.security.NoSuchAlgorithmException;
|
||||||
|
import java.security.SecureRandom;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
import javax.crypto.BadPaddingException;
|
||||||
|
import javax.crypto.Cipher;
|
||||||
|
import javax.crypto.IllegalBlockSizeException;
|
||||||
|
import javax.crypto.NoSuchPaddingException;
|
||||||
|
import javax.crypto.spec.IvParameterSpec;
|
||||||
|
import javax.crypto.spec.SecretKeySpec;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define encryption and decryption methods.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class AESEncrypt {
|
public class AESEncrypt {
|
||||||
private static final Logger LOGGER = Logger.getLogger(AESEncrypt.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(AESEncrypt.class.toString());
|
||||||
/**
|
|
||||||
* 128, 192 or 256
|
|
||||||
*/
|
|
||||||
private static final int KEY_SIZE = 256;
|
|
||||||
private static final int I_VEC_LEN = 16;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* encrypt/decrypt algorithm name
|
* encrypt/decrypt algorithm name
|
||||||
|
@ -43,62 +55,145 @@ public class AESEncrypt {
|
||||||
private static final String CIPHER_MODE_CTR = "AES/CTR/NoPadding";
|
private static final String CIPHER_MODE_CTR = "AES/CTR/NoPadding";
|
||||||
private static final String CIPHER_MODE_CBC = "AES/CBC/PKCS5PADDING";
|
private static final String CIPHER_MODE_CBC = "AES/CBC/PKCS5PADDING";
|
||||||
|
|
||||||
private String CIPHER_MODE;
|
private String cipherMod;
|
||||||
|
|
||||||
private static final int RANDOM_LEN = KEY_SIZE / 8;
|
/**
|
||||||
|
* Defining a Constructor of the class AESEncrypt.
|
||||||
private String iVecS = "1111111111111111";
|
*
|
||||||
private byte[] iVec = iVecS.getBytes("utf-8");
|
* @param key the Key.
|
||||||
|
* @param mode the encryption Mode.
|
||||||
public AESEncrypt(byte[] key, byte[] iVecIn, String mode) throws UnsupportedEncodingException {
|
*/
|
||||||
|
public AESEncrypt(byte[] key, String mode) {
|
||||||
if (key == null) {
|
if (key == null) {
|
||||||
LOGGER.severe(Common.addTag("Key is null"));
|
LOGGER.severe(Common.addTag("Key is null"));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (key.length != KEY_SIZE / 8) {
|
if (key.length != KEY_LEN) {
|
||||||
LOGGER.severe(Common.addTag("the length of key is not correct"));
|
LOGGER.severe(Common.addTag("the length of key is not correct"));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mode.contains("CBC")) {
|
if (mode.contains("CBC")) {
|
||||||
CIPHER_MODE = CIPHER_MODE_CBC;
|
cipherMod = CIPHER_MODE_CBC;
|
||||||
} else if (mode.contains("CTR")) {
|
} else if (mode.contains("CTR")) {
|
||||||
CIPHER_MODE = CIPHER_MODE_CTR;
|
cipherMod = CIPHER_MODE_CTR;
|
||||||
} else {
|
} else {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (iVecIn == null || iVecIn.length != I_VEC_LEN) {
|
}
|
||||||
return;
|
|
||||||
|
/**
|
||||||
|
* Defining the CBC encryption Mode.
|
||||||
|
*
|
||||||
|
* @param key the Key.
|
||||||
|
* @param data the data to be encrypted.
|
||||||
|
* @return the data to be encrypted.
|
||||||
|
*/
|
||||||
|
public byte[] encrypt(byte[] key, byte[] data) {
|
||||||
|
if (key == null) {
|
||||||
|
LOGGER.severe(Common.addTag("Key is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
if (data == null) {
|
||||||
|
LOGGER.severe(Common.addTag("data is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
byte[] iVec = new byte[I_VEC_LEN];
|
||||||
|
SecureRandom secureRandom = Common.getSecureRandom();
|
||||||
|
secureRandom.nextBytes(iVec);
|
||||||
|
|
||||||
|
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||||
|
Cipher cipher = Cipher.getInstance(cipherMod);
|
||||||
|
|
||||||
|
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||||
|
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
||||||
|
byte[] encrypted = cipher.doFinal(data);
|
||||||
|
byte[] encryptedAddIv = new byte[encrypted.length + iVec.length];
|
||||||
|
System.arraycopy(iVec, 0, encryptedAddIv, 0, iVec.length);
|
||||||
|
System.arraycopy(encrypted, 0, encryptedAddIv, iVec.length, encrypted.length);
|
||||||
|
return encryptedAddIv;
|
||||||
|
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
|
||||||
|
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
|
||||||
|
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException or " +
|
||||||
|
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
|
||||||
|
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
|
||||||
|
return new byte[0];
|
||||||
}
|
}
|
||||||
iVec = iVecIn;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public byte[] encrypt(byte[] key, byte[] data) throws Exception {
|
/**
|
||||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
* Defining the CTR encryption Mode.
|
||||||
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
|
*
|
||||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
* @param key the Key.
|
||||||
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
* @param data the data to be encrypted.
|
||||||
byte[] encrypted = cipher.doFinal(data);
|
* @param iVec the IV value.
|
||||||
String encryptResultStr = BaseUtil.byte2HexString(encrypted);
|
* @return the data to be encrypted.
|
||||||
return encrypted;
|
*/
|
||||||
|
public byte[] encryptCTR(byte[] key, byte[] data, byte[] iVec) {
|
||||||
|
if (key == null) {
|
||||||
|
LOGGER.severe(Common.addTag("Key is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
if (data == null) {
|
||||||
|
LOGGER.severe(Common.addTag("data is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
if (iVec == null || iVec.length != I_VEC_LEN) {
|
||||||
|
LOGGER.severe(Common.addTag("iVec is null or the length of iVec is not valid, it should be " + "I_VEC_LEN"
|
||||||
|
));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||||
|
Cipher cipher = Cipher.getInstance(cipherMod);
|
||||||
|
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||||
|
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
||||||
|
return cipher.doFinal(data);
|
||||||
|
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
|
||||||
|
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
|
||||||
|
LOGGER.severe(Common.addTag("[encryptCTR] catch NoSuchAlgorithmException or " +
|
||||||
|
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
|
||||||
|
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public byte[] encryptCTR(byte[] key, byte[] data) throws Exception {
|
/**
|
||||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
* Defining the decrypt method.
|
||||||
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
|
*
|
||||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
* @param key the Key.
|
||||||
cipher.init(Cipher.ENCRYPT_MODE, skeySpec, iv);
|
* @param encryptDataAddIv the data to be decrypted.
|
||||||
byte[] encrypted = cipher.doFinal(data);
|
* @return the data to be decrypted.
|
||||||
return encrypted;
|
*/
|
||||||
|
public byte[] decrypt(byte[] key, byte[] encryptDataAddIv) {
|
||||||
|
if (key == null) {
|
||||||
|
LOGGER.severe(Common.addTag("Key is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
if (encryptDataAddIv == null) {
|
||||||
|
LOGGER.severe(Common.addTag("encryptDataAddIv is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
if (encryptDataAddIv.length <= I_VEC_LEN) {
|
||||||
|
LOGGER.severe(Common.addTag("the length of encryptDataAddIv is not valid: " + encryptDataAddIv.length +
|
||||||
|
", it should be > " + I_VEC_LEN));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
byte[] iVec = Arrays.copyOfRange(encryptDataAddIv, 0, I_VEC_LEN);
|
||||||
|
byte[] encryptData = Arrays.copyOfRange(encryptDataAddIv, I_VEC_LEN, encryptDataAddIv.length);
|
||||||
|
SecretKeySpec sKeySpec = new SecretKeySpec(key, ALGORITHM);
|
||||||
|
Cipher cipher = Cipher.getInstance(cipherMod);
|
||||||
|
IvParameterSpec iv = new IvParameterSpec(iVec);
|
||||||
|
cipher.init(Cipher.DECRYPT_MODE, sKeySpec, iv);
|
||||||
|
return cipher.doFinal(encryptData);
|
||||||
|
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException |
|
||||||
|
InvalidAlgorithmParameterException | IllegalBlockSizeException | BadPaddingException ex) {
|
||||||
|
LOGGER.severe(Common.addTag("catch NoSuchAlgorithmException or " +
|
||||||
|
"NoSuchPaddingException or InvalidKeyException or InvalidAlgorithmParameterException or " +
|
||||||
|
"IllegalBlockSizeException or BadPaddingException: " + ex.getMessage()));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public byte[] decrypt(byte[] key, byte[] encryptData) throws Exception {
|
|
||||||
SecretKeySpec skeySpec = new SecretKeySpec(key, ALGORITHM);
|
|
||||||
Cipher cipher = Cipher.getInstance(CIPHER_MODE);
|
|
||||||
IvParameterSpec iv = new IvParameterSpec(iVec);
|
|
||||||
cipher.init(Cipher.DECRYPT_MODE, skeySpec, iv);
|
|
||||||
byte[] origin = cipher.doFinal(encryptData);
|
|
||||||
return origin;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,18 +1,19 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
* <p>
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
* You may obtain a copy of the License at
|
* You may obtain a copy of the License at
|
||||||
* <p>
|
*
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
* <p>
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher;
|
package com.mindspore.flclient.cipher;
|
||||||
|
|
||||||
import java.io.UnsupportedEncodingException;
|
import java.io.UnsupportedEncodingException;
|
||||||
|
@ -21,14 +22,23 @@ import java.nio.charset.Charset;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define conversion methods between basic data types.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class BaseUtil {
|
public class BaseUtil {
|
||||||
private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
|
private static final char[] HEX_DIGITS = new char[]{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B',
|
||||||
|
'C', 'D', 'E', 'F'};
|
||||||
public BaseUtil() {
|
|
||||||
}
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert byte[] to String in hexadecimal format.
|
||||||
|
*
|
||||||
|
* @param bytes the byte[] object.
|
||||||
|
* @return the String object converted from byte[].
|
||||||
|
*/
|
||||||
public static String byte2HexString(byte[] bytes) {
|
public static String byte2HexString(byte[] bytes) {
|
||||||
if (null == bytes) {
|
if (bytes == null) {
|
||||||
return null;
|
return null;
|
||||||
} else if (bytes.length == 0) {
|
} else if (bytes.length == 0) {
|
||||||
return "";
|
return "";
|
||||||
|
@ -36,14 +46,20 @@ public class BaseUtil {
|
||||||
char[] chars = new char[bytes.length * 2];
|
char[] chars = new char[bytes.length * 2];
|
||||||
|
|
||||||
for (int i = 0; i < bytes.length; ++i) {
|
for (int i = 0; i < bytes.length; ++i) {
|
||||||
int b = bytes[i];
|
int byteNum = bytes[i];
|
||||||
chars[i * 2] = HEX_DIGITS[(b & 240) >> 4];
|
chars[i * 2] = HEX_DIGITS[(byteNum & 240) >> 4];
|
||||||
chars[i * 2 + 1] = HEX_DIGITS[b & 15];
|
chars[i * 2 + 1] = HEX_DIGITS[byteNum & 15];
|
||||||
}
|
}
|
||||||
return new String(chars);
|
return new String(chars);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert String in hexadecimal format to byte[].
|
||||||
|
*
|
||||||
|
* @param str the String object.
|
||||||
|
* @return the byte[] converted from String object.
|
||||||
|
*/
|
||||||
public static byte[] hexString2ByteArray(String str) {
|
public static byte[] hexString2ByteArray(String str) {
|
||||||
int length = str.length() / 2;
|
int length = str.length() / 2;
|
||||||
byte[] bytes = new byte[length];
|
byte[] bytes = new byte[length];
|
||||||
|
@ -58,8 +74,13 @@ public class BaseUtil {
|
||||||
return bytes;
|
return bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert byte[] to BigInteger.
|
||||||
|
*
|
||||||
|
* @param bytes the byte[] object.
|
||||||
|
* @return the BigInteger object converted from byte[].
|
||||||
|
*/
|
||||||
public static BigInteger byteArray2BigInteger(byte[] bytes) {
|
public static BigInteger byteArray2BigInteger(byte[] bytes) {
|
||||||
|
|
||||||
BigInteger bigInteger = BigInteger.ZERO;
|
BigInteger bigInteger = BigInteger.ZERO;
|
||||||
for (int i = 0; i < bytes.length; ++i) {
|
for (int i = 0; i < bytes.length; ++i) {
|
||||||
int intI = bytes[i];
|
int intI = bytes[i];
|
||||||
|
@ -72,6 +93,13 @@ public class BaseUtil {
|
||||||
return bigInteger;
|
return bigInteger;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert String to BigInteger.
|
||||||
|
*
|
||||||
|
* @param str the String object.
|
||||||
|
* @return the BigInteger object converted from String object.
|
||||||
|
* @throws UnsupportedEncodingException if the encoding is not supported.
|
||||||
|
*/
|
||||||
public static BigInteger string2BigInteger(String str) throws UnsupportedEncodingException {
|
public static BigInteger string2BigInteger(String str) throws UnsupportedEncodingException {
|
||||||
StringBuilder res = new StringBuilder();
|
StringBuilder res = new StringBuilder();
|
||||||
byte[] bytes = String.valueOf(str).getBytes("UTF-8");
|
byte[] bytes = String.valueOf(str).getBytes("UTF-8");
|
||||||
|
@ -83,14 +111,20 @@ public class BaseUtil {
|
||||||
return bigInteger;
|
return bigInteger;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static String bigInteger2String(BigInteger bigInteger) throws UnsupportedEncodingException {
|
/**
|
||||||
|
* Convert BigInteger to String.
|
||||||
|
*
|
||||||
|
* @param bigInteger the BigInteger object.
|
||||||
|
* @return the String object converted from BigInteger.
|
||||||
|
*/
|
||||||
|
public static String bigInteger2String(BigInteger bigInteger) {
|
||||||
StringBuilder res = new StringBuilder();
|
StringBuilder res = new StringBuilder();
|
||||||
List<Integer> lists = new ArrayList<>();
|
List<Integer> lists = new ArrayList<>();
|
||||||
BigInteger bi = bigInteger;
|
BigInteger bi = bigInteger;
|
||||||
BigInteger DIV = BigInteger.valueOf(256);
|
BigInteger div = BigInteger.valueOf(256);
|
||||||
while (bi.compareTo(BigInteger.ZERO) > 0) {
|
while (bi.compareTo(BigInteger.ZERO) > 0) {
|
||||||
lists.add(bi.mod(DIV).intValue());
|
lists.add(bi.mod(div).intValue());
|
||||||
bi = bi.divide(DIV);
|
bi = bi.divide(div);
|
||||||
}
|
}
|
||||||
for (int i = lists.size() - 1; i >= 0; --i) {
|
for (int i = lists.size() - 1; i >= 0; --i) {
|
||||||
res.append((char) (int) (lists.get(i)));
|
res.append((char) (int) (lists.get(i)));
|
||||||
|
@ -98,13 +132,19 @@ public class BaseUtil {
|
||||||
return res.toString();
|
return res.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static byte[] bigInteger2byteArray(BigInteger bigInteger) throws UnsupportedEncodingException {
|
/**
|
||||||
|
* Convert BigInteger to byte[].
|
||||||
|
*
|
||||||
|
* @param bigInteger the BigInteger object.
|
||||||
|
* @return the byte[] object converted from BigInteger.
|
||||||
|
*/
|
||||||
|
public static byte[] bigInteger2byteArray(BigInteger bigInteger) {
|
||||||
List<Integer> lists = new ArrayList<>();
|
List<Integer> lists = new ArrayList<>();
|
||||||
BigInteger bi = bigInteger;
|
BigInteger bi = bigInteger;
|
||||||
BigInteger DIV = BigInteger.valueOf(256);
|
BigInteger div = BigInteger.valueOf(256);
|
||||||
while (bi.compareTo(BigInteger.ZERO) > 0) {
|
while (bi.compareTo(BigInteger.ZERO) > 0) {
|
||||||
lists.add(bi.mod(DIV).intValue());
|
lists.add(bi.mod(div).intValue());
|
||||||
bi = bi.divide(DIV);
|
bi = bi.divide(div);
|
||||||
}
|
}
|
||||||
byte[] res = new byte[lists.size()];
|
byte[] res = new byte[lists.size()];
|
||||||
for (int i = lists.size() - 1; i >= 0; --i) {
|
for (int i = lists.size() - 1; i >= 0; --i) {
|
||||||
|
@ -113,13 +153,19 @@ public class BaseUtil {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert Integer to byte[].
|
||||||
|
*
|
||||||
|
* @param num the Integer object.
|
||||||
|
* @return the byte[] object converted from Integer.
|
||||||
|
*/
|
||||||
public static byte[] integer2byteArray(Integer num) {
|
public static byte[] integer2byteArray(Integer num) {
|
||||||
List<Integer> lists = new ArrayList<>();
|
List<Integer> lists = new ArrayList<>();
|
||||||
Integer bi = num;
|
Integer bi = num;
|
||||||
Integer DIV = 256;
|
Integer div = 256;
|
||||||
while (bi > 0) {
|
while (bi > 0) {
|
||||||
lists.add(bi % DIV);
|
lists.add(bi % div);
|
||||||
bi = bi / DIV;
|
bi = bi / div;
|
||||||
}
|
}
|
||||||
byte[] res = new byte[lists.size()];
|
byte[] res = new byte[lists.size()];
|
||||||
for (int i = lists.size() - 1; i >= 0; --i) {
|
for (int i = lists.size() - 1; i >= 0; --i) {
|
||||||
|
@ -128,8 +174,13 @@ public class BaseUtil {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert byte[] to Integer.
|
||||||
|
*
|
||||||
|
* @param bytes the byte[] object.
|
||||||
|
* @return the Integer object converted from byte[].
|
||||||
|
*/
|
||||||
public static Integer byteArray2Integer(byte[] bytes) {
|
public static Integer byteArray2Integer(byte[] bytes) {
|
||||||
|
|
||||||
Integer num = 0;
|
Integer num = 0;
|
||||||
for (int i = 0; i < bytes.length; ++i) {
|
for (int i = 0; i < bytes.length; ++i) {
|
||||||
int intI = bytes[i];
|
int intI = bytes[i];
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -14,9 +14,13 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher;
|
package com.mindspore.flclient.cipher;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
|
|
||||||
import com.mindspore.flclient.Common;
|
import com.mindspore.flclient.Common;
|
||||||
import com.mindspore.flclient.FLClientStatus;
|
import com.mindspore.flclient.FLClientStatus;
|
||||||
import com.mindspore.flclient.FLCommunication;
|
import com.mindspore.flclient.FLCommunication;
|
||||||
|
@ -25,23 +29,27 @@ import com.mindspore.flclient.LocalFLParameter;
|
||||||
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
||||||
import com.mindspore.flclient.cipher.struct.EncryptShare;
|
import com.mindspore.flclient.cipher.struct.EncryptShare;
|
||||||
import com.mindspore.flclient.cipher.struct.NewArray;
|
import com.mindspore.flclient.cipher.struct.NewArray;
|
||||||
|
|
||||||
import mindspore.schema.GetClientList;
|
import mindspore.schema.GetClientList;
|
||||||
import mindspore.schema.ResponseCode;
|
import mindspore.schema.ResponseCode;
|
||||||
import mindspore.schema.ReturnClientList;
|
import mindspore.schema.ReturnClientList;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.time.LocalDateTime;
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
/**
|
||||||
import static com.mindspore.flclient.LocalFLParameter.IVEC_LEN;
|
* Define the serialization method, handle the response message returned from server for GetClientList request.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class ClientListReq {
|
public class ClientListReq {
|
||||||
|
|
||||||
private static final Logger LOGGER = Logger.getLogger(ClientListReq.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(ClientListReq.class.toString());
|
||||||
|
|
||||||
private FLCommunication flCommunication;
|
private FLCommunication flCommunication;
|
||||||
private String nextRequestTime;
|
private String nextRequestTime;
|
||||||
private FLParameter flParameter = FLParameter.getInstance();
|
private FLParameter flParameter = FLParameter.getInstance();
|
||||||
|
@ -64,34 +72,63 @@ public class ClientListReq {
|
||||||
return retCode;
|
return retCode;
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus getClientList(int iteration, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
/**
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
* Send serialized request message of GetClientList to server.
|
||||||
|
*
|
||||||
|
* @param iteration current iteration of federated learning task.
|
||||||
|
* @param u3ClientList list of clients successfully requested in UpdateModel round.
|
||||||
|
* @param decryptSecretsList list to store to decrypted secret fragments.
|
||||||
|
* @param returnShareList List of returned secret fragments from server.
|
||||||
|
* @param cuvKeys Keys used to decrypt secret fragments.
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
|
public FLClientStatus getClientList(int iteration, List<String> u3ClientList,
|
||||||
|
List<DecryptShareSecrets> decryptSecretsList,
|
||||||
|
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||||
FlatBufferBuilder builder = new FlatBufferBuilder();
|
FlatBufferBuilder builder = new FlatBufferBuilder();
|
||||||
int id = builder.createString(localFLParameter.getFlID());
|
int id = builder.createString(localFLParameter.getFlID());
|
||||||
String dateTime = LocalDateTime.now().toString();
|
Date date = new Date();
|
||||||
|
long timestamp = date.getTime();
|
||||||
|
String dateTime = String.valueOf(timestamp);
|
||||||
int time = builder.createString(dateTime);
|
int time = builder.createString(dateTime);
|
||||||
int clientListRoot = GetClientList.createGetClientList(builder, id, iteration, time);
|
int clientListRoot = GetClientList.createGetClientList(builder, id, iteration, time);
|
||||||
builder.finish(clientListRoot);
|
builder.finish(clientListRoot);
|
||||||
byte[] msg = builder.sizedByteArray();
|
byte[] msg = builder.sizedByteArray();
|
||||||
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
|
||||||
try {
|
try {
|
||||||
byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg);
|
byte[] responseData = flCommunication.syncRequest(url + "/getClientList", msg);
|
||||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(responseData)) {
|
||||||
LOGGER.info(Common.addTag("[getClientList] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[getClientList] the server is not ready now, need wait some time and " +
|
||||||
|
"request again"));
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
nextRequestTime = "";
|
nextRequestTime = "";
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
}
|
}
|
||||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||||
|
LOGGER.info(Common.addTag("getClientList responseData size: " + responseData.length));
|
||||||
ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer);
|
ReturnClientList clientListRsp = ReturnClientList.getRootAsReturnClientList(buffer);
|
||||||
FLClientStatus status = judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);
|
return judgeGetClientList(clientListRsp, u3ClientList, decryptSecretsList, returnShareList, cuvKeys);
|
||||||
return status;
|
} catch (IOException ex) {
|
||||||
} catch (Exception e) {
|
LOGGER.severe(Common.addTag("[getClientList] unsolved error code in getClientList: catch IOException: " +
|
||||||
e.printStackTrace();
|
ex.getMessage()));
|
||||||
|
retCode = ResponseCode.RequestError;
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList, List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
/**
|
||||||
|
* Analyze the serialization message returned from server and perform corresponding processing.
|
||||||
|
*
|
||||||
|
* @param bufData Serialized message returned from server.
|
||||||
|
* @param u3ClientList list of clients successfully requested in UpdateModel round.
|
||||||
|
* @param decryptSecretsList list to store decrypted secret fragments.
|
||||||
|
* @param returnShareList List of returned secret fragments from server.
|
||||||
|
* @param cuvKeys Keys used to decrypt secret fragments.
|
||||||
|
* @return the status code corresponding to the response message.
|
||||||
|
*/
|
||||||
|
private FLClientStatus judgeGetClientList(ReturnClientList bufData, List<String> u3ClientList,
|
||||||
|
List<DecryptShareSecrets> decryptSecretsList,
|
||||||
|
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||||
retCode = bufData.retcode();
|
retCode = bufData.retcode();
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************"));
|
LOGGER.info(Common.addTag("[PairWiseMask] ************** the response of GetClientList **************"));
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||||
|
@ -109,18 +146,15 @@ public class ClientListReq {
|
||||||
String curFlId = bufData.clients(i);
|
String curFlId = bufData.clients(i);
|
||||||
u3ClientList.add(curFlId);
|
u3ClientList.add(curFlId);
|
||||||
}
|
}
|
||||||
try {
|
status = decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys);
|
||||||
decryptSecretShares(decryptSecretsList, returnShareList, cuvKeys);
|
return status;
|
||||||
} catch (Exception e) {
|
|
||||||
e.printStackTrace();
|
|
||||||
return FLClientStatus.FAILED;
|
|
||||||
}
|
|
||||||
return FLClientStatus.SUCCESS;
|
|
||||||
case (ResponseCode.SucNotReady):
|
case (ResponseCode.SucNotReady):
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request GetClientList again!"));
|
LOGGER.info(Common.addTag("[PairWiseMask] server is not ready now, need wait and request " +
|
||||||
|
"GetClientList again!"));
|
||||||
return FLClientStatus.WAIT;
|
return FLClientStatus.WAIT;
|
||||||
case (ResponseCode.OutOfTime):
|
case (ResponseCode.OutOfTime):
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob again"));
|
LOGGER.info(Common.addTag("[PairWiseMask] GetClientList out of time: need wait and request startFLJob" +
|
||||||
|
" again"));
|
||||||
setNextRequestTime(bufData.nextReqTime());
|
setNextRequestTime(bufData.nextReqTime());
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
case (ResponseCode.RequestError):
|
case (ResponseCode.RequestError):
|
||||||
|
@ -128,36 +162,66 @@ public class ClientListReq {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList"));
|
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in GetClientList"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
default:
|
default:
|
||||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is invalid: " + retCode));
|
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReturnClientList is " +
|
||||||
|
"invalid: " + retCode));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList, List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) throws Exception {
|
private FLClientStatus decryptSecretShares(List<DecryptShareSecrets> decryptSecretsList,
|
||||||
|
List<EncryptShare> returnShareList, Map<String, byte[]> cuvKeys) {
|
||||||
decryptSecretsList.clear();
|
decryptSecretsList.clear();
|
||||||
int size = returnShareList.size();
|
int size = returnShareList.size();
|
||||||
|
if (size <= 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <returnShareList> is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
if (cuvKeys.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[PairWiseMask] the input argument <cuvKeys> is null"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
for (int i = 0; i < size; i++) {
|
for (int i = 0; i < size; i++) {
|
||||||
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
|
|
||||||
EncryptShare encryptShare = returnShareList.get(i);
|
EncryptShare encryptShare = returnShareList.get(i);
|
||||||
String vFlID = encryptShare.getFlID();
|
String vFlID = encryptShare.getFlID();
|
||||||
byte[] share = encryptShare.getShare().getArray();
|
byte[] share = encryptShare.getShare().getArray();
|
||||||
byte[] iVecIn = new byte[IVEC_LEN];
|
if (!cuvKeys.containsKey(vFlID)) {
|
||||||
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), iVecIn, "CBC");
|
LOGGER.severe(Common.addTag("[PairWiseMask] the key <vFlID> is not in map <cuvKeys> "));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
AESEncrypt aesEncrypt = new AESEncrypt(cuvKeys.get(vFlID), "CBC");
|
||||||
byte[] decryptShare = aesEncrypt.decrypt(cuvKeys.get(vFlID), share);
|
byte[] decryptShare = aesEncrypt.decrypt(cuvKeys.get(vFlID), share);
|
||||||
|
if (decryptShare == null || decryptShare.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[decryptSecretShares] the return byte[] is null, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
if (decryptShare.length < 4) {
|
||||||
|
LOGGER.severe(Common.addTag("[decryptSecretShares] the returned decryptShare is not valid: length is " +
|
||||||
|
"not right, please check!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
int sSize = (int) decryptShare[0];
|
int sSize = (int) decryptShare[0];
|
||||||
int bSize = (int) decryptShare[1];
|
int bSize = (int) decryptShare[1];
|
||||||
int sIndexLen = (int) decryptShare[2];
|
int sIndexLen = (int) decryptShare[2];
|
||||||
int bIndexLen = (int) decryptShare[3];
|
int bIndexLen = (int) decryptShare[3];
|
||||||
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen));
|
if (decryptShare.length < (4 + sIndexLen + bIndexLen + sSize + bSize)) {
|
||||||
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen, 4 + sIndexLen + bIndexLen));
|
LOGGER.severe(Common.addTag("[decryptSecretShares] the returned decryptShare is not valid: length is " +
|
||||||
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen, 4 + sIndexLen + bIndexLen + sSize);
|
"not right, please check!"));
|
||||||
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize, 4 + sIndexLen + bIndexLen + sSize + bSize);
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
byte[] sSkUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen,
|
||||||
|
4 + sIndexLen + bIndexLen + sSize);
|
||||||
|
byte[] bUv = Arrays.copyOfRange(decryptShare, 4 + sIndexLen + bIndexLen + sSize,
|
||||||
|
4 + sIndexLen + bIndexLen + sSize + bSize);
|
||||||
NewArray<byte[]> sSkVu = new NewArray<>();
|
NewArray<byte[]> sSkVu = new NewArray<>();
|
||||||
sSkVu.setSize(sSize);
|
sSkVu.setSize(sSize);
|
||||||
sSkVu.setArray(sSkUv);
|
sSkVu.setArray(sSkUv);
|
||||||
NewArray bVu = new NewArray();
|
NewArray bVu = new NewArray();
|
||||||
bVu.setSize(bSize);
|
bVu.setSize(bSize);
|
||||||
bVu.setArray(bUv);
|
bVu.setArray(bUv);
|
||||||
|
int sIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4, 4 + sIndexLen));
|
||||||
|
int bIndex = BaseUtil.byteArray2Integer(Arrays.copyOfRange(decryptShare, 4 + sIndexLen,
|
||||||
|
4 + sIndexLen + bIndexLen));
|
||||||
|
DecryptShareSecrets decryptShareSecrets = new DecryptShareSecrets();
|
||||||
decryptShareSecrets.setFlID(vFlID);
|
decryptShareSecrets.setFlID(vFlID);
|
||||||
decryptShareSecrets.setSSkVu(sSkVu);
|
decryptShareSecrets.setSSkVu(sSkVu);
|
||||||
decryptShareSecrets.setBVu(bVu);
|
decryptShareSecrets.setBVu(bVu);
|
||||||
|
@ -165,5 +229,6 @@ public class ClientListReq {
|
||||||
decryptShareSecrets.setIndexB(bIndex);
|
decryptShareSecrets.setIndexB(bIndex);
|
||||||
decryptSecretsList.add(decryptShareSecrets);
|
decryptSecretsList.add(decryptShareSecrets);
|
||||||
}
|
}
|
||||||
|
return FLClientStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
|
/*
|
||||||
/**
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,6 +16,10 @@
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher;
|
package com.mindspore.flclient.cipher;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.LocalFLParameter.KEY_LEN;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.Common;
|
||||||
|
|
||||||
import org.bouncycastle.crypto.digests.SHA256Digest;
|
import org.bouncycastle.crypto.digests.SHA256Digest;
|
||||||
import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator;
|
import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator;
|
||||||
import org.bouncycastle.crypto.params.KeyParameter;
|
import org.bouncycastle.crypto.params.KeyParameter;
|
||||||
|
@ -25,39 +28,80 @@ import org.bouncycastle.math.ec.rfc7748.X25519;
|
||||||
import java.security.SecureRandom;
|
import java.security.SecureRandom;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate public-private key pairs and DH Keys.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class KEYAgreement {
|
public class KEYAgreement {
|
||||||
private static final Logger LOGGER = Logger.getLogger(KEYAgreement.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(KEYAgreement.class.toString());
|
||||||
private static final int PBKDF2_ITERATIONS = 10000;
|
private static final int PBKDF2_ITERATIONS = 10000;
|
||||||
private static final int SALT_SIZE = 32;
|
|
||||||
private static final int HASH_BIT_SIZE = 256;
|
private static final int HASH_BIT_SIZE = 256;
|
||||||
private static final int KEY_LEN = X25519.SCALAR_SIZE;
|
|
||||||
|
|
||||||
private SecureRandom random = new SecureRandom();
|
private SecureRandom random = Common.getSecureRandom();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate private Key.
|
||||||
|
*
|
||||||
|
* @return the private Key.
|
||||||
|
*/
|
||||||
public byte[] generatePrivateKey() {
|
public byte[] generatePrivateKey() {
|
||||||
byte[] privateKey = new byte[KEY_LEN];
|
byte[] privateKey = new byte[KEY_LEN];
|
||||||
X25519.generatePrivateKey(random, privateKey);
|
X25519.generatePrivateKey(random, privateKey);
|
||||||
return privateKey;
|
return privateKey;
|
||||||
}
|
}
|
||||||
|
|
||||||
public byte[] generatePublicKey(byte[] privatekey) {
|
/**
|
||||||
|
* Use private Key to generate public Key.
|
||||||
|
*
|
||||||
|
* @param privateKey the private Key.
|
||||||
|
* @return the public Key.
|
||||||
|
*/
|
||||||
|
public byte[] generatePublicKey(byte[] privateKey) {
|
||||||
|
if (privateKey == null || privateKey.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("privateKey is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
byte[] publicKey = new byte[KEY_LEN];
|
byte[] publicKey = new byte[KEY_LEN];
|
||||||
X25519.generatePublicKey(privatekey, 0, publicKey, 0);
|
X25519.generatePublicKey(privateKey, 0, publicKey, 0);
|
||||||
return publicKey;
|
return publicKey;
|
||||||
}
|
}
|
||||||
|
|
||||||
public byte[] keyAgreement(byte[] privatekey, byte[] publicKey) {
|
/**
|
||||||
|
* Use private Key and public Key to generate DH Key.
|
||||||
|
*
|
||||||
|
* @param privateKey the private Key.
|
||||||
|
* @param publicKey the public Key.
|
||||||
|
* @return the DH Key.
|
||||||
|
*/
|
||||||
|
public byte[] keyAgreement(byte[] privateKey, byte[] publicKey) {
|
||||||
|
if (privateKey == null || privateKey.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("privateKey is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
|
if (publicKey == null || publicKey.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("publicKey is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
byte[] secret = new byte[KEY_LEN];
|
byte[] secret = new byte[KEY_LEN];
|
||||||
X25519.calculateAgreement(privatekey, 0, publicKey, 0, secret, 0);
|
X25519.calculateAgreement(privateKey, 0, publicKey, 0, secret, 0);
|
||||||
return secret;
|
return secret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Encrypt DH Key.
|
||||||
|
*
|
||||||
|
* @param password the DH Key.
|
||||||
|
* @param salt the salt value.
|
||||||
|
* @return encrypted DH Key.
|
||||||
|
*/
|
||||||
public byte[] getEncryptedPassword(byte[] password, byte[] salt) {
|
public byte[] getEncryptedPassword(byte[] password, byte[] salt) {
|
||||||
|
if (password == null || password.length == 0) {
|
||||||
byte[] saltB = new byte[SALT_SIZE];
|
LOGGER.severe(Common.addTag("password is null"));
|
||||||
|
return new byte[0];
|
||||||
|
}
|
||||||
PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest());
|
PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest());
|
||||||
gen.init(password, saltB, PBKDF2_ITERATIONS);
|
gen.init(password, salt, PBKDF2_ITERATIONS);
|
||||||
byte[] dk = ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey();
|
return ((KeyParameter) gen.generateDerivedParameters(HASH_BIT_SIZE)).getKey();
|
||||||
return dk;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,115 @@
|
||||||
|
/*
|
||||||
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package com.mindspore.flclient.cipher;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.Common;
|
||||||
|
|
||||||
|
import java.security.SecureRandom;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define the basic method for generating pairwise mask and individual mask.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
|
public class Masking {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(Masking.class.toString());
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Random generate RNG algorithm name.
|
||||||
|
*/
|
||||||
|
private static final String RNG_ALGORITHM = "SHA1PRNG";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate individual mask.
|
||||||
|
*
|
||||||
|
* @param secret used to store individual mask.
|
||||||
|
* @return the int value, 0 indicates success, -1 indicates failed .
|
||||||
|
*/
|
||||||
|
public int getRandomBytes(byte[] secret) {
|
||||||
|
if (secret == null || secret.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[Masking] the input argument <secret> is null, please check!"));
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
SecureRandom secureRandom = Common.getSecureRandom();
|
||||||
|
secureRandom.nextBytes(secret);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate pairwise mask.
|
||||||
|
*
|
||||||
|
* @param noise used to store pairwise mask.
|
||||||
|
* @param length the length of individual mask.
|
||||||
|
* @param seed the seed for generate pairwise mask.
|
||||||
|
* @param iVec the IV value.
|
||||||
|
* @return the int value, 0 indicates success, -1 indicates failed .
|
||||||
|
*/
|
||||||
|
public int getMasking(List<Float> noise, int length, byte[] seed, byte[] iVec) {
|
||||||
|
if (length <= 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[Masking] the input argument <length> is not valid: <= 0, please check!"));
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
int intV = Integer.SIZE / 8;
|
||||||
|
int size = length * intV;
|
||||||
|
byte[] data = new byte[size];
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
data[i] = 0;
|
||||||
|
}
|
||||||
|
AESEncrypt aesEncrypt = new AESEncrypt(seed, "CTR");
|
||||||
|
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data, iVec);
|
||||||
|
if (encryptCtr == null || encryptCtr.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("[Masking] the return byte[] is null, please check!"));
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < length; i++) {
|
||||||
|
int[] sub = new int[intV];
|
||||||
|
for (int j = 0; j < 4; j++) {
|
||||||
|
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
|
||||||
|
}
|
||||||
|
int subI = byte2int(sub, 4);
|
||||||
|
if (subI == -1) {
|
||||||
|
LOGGER.severe(Common.addTag("[Masking] the the returned <subI> is not valid: -1, please check!"));
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
Float fNoise = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
|
||||||
|
noise.add(fNoise);
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int byte2int(int[] data, int number) {
|
||||||
|
if (data.length < 4) {
|
||||||
|
LOGGER.severe(Common.addTag("[Masking] the input argument <data> is not valid: length < 4, please check!"));
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
switch (number) {
|
||||||
|
case 1:
|
||||||
|
return (int) data[0];
|
||||||
|
case 2:
|
||||||
|
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
|
||||||
|
case 3:
|
||||||
|
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
|
||||||
|
case 4:
|
||||||
|
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
|
||||||
|
| (data[3] << 24 & 0xff000000);
|
||||||
|
default:
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,82 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher;
|
|
||||||
|
|
||||||
import java.security.NoSuchAlgorithmException;
|
|
||||||
import java.security.SecureRandom;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.logging.Logger;
|
|
||||||
|
|
||||||
public class Random {
|
|
||||||
/**
|
|
||||||
* random generate RNG algorithm name
|
|
||||||
*/
|
|
||||||
private static final Logger LOGGER = Logger.getLogger(Random.class.toString());
|
|
||||||
private static final String RNG_ALGORITHM = "SHA1PRNG";
|
|
||||||
|
|
||||||
private static final int RANDOM_LEN = 128 / 8;
|
|
||||||
|
|
||||||
public void getRandomBytes(byte[] secret) {
|
|
||||||
try {
|
|
||||||
SecureRandom secureRandom = SecureRandom.getInstance("SHA1PRNG");
|
|
||||||
secureRandom.nextBytes(secret);
|
|
||||||
} catch (NoSuchAlgorithmException e) {
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void randomAESCTR(List<Float> noise, int length, byte[] seed) throws Exception {
|
|
||||||
int intV = Integer.SIZE / 8;
|
|
||||||
int size = length * intV;
|
|
||||||
byte[] data = new byte[size];
|
|
||||||
for (int i = 0; i < size; i++) {
|
|
||||||
data[i] = 0;
|
|
||||||
}
|
|
||||||
byte[] ivec = new byte[RANDOM_LEN];
|
|
||||||
AESEncrypt aesEncrypt = new AESEncrypt(seed, ivec, "CTR");
|
|
||||||
byte[] encryptCtr = aesEncrypt.encryptCTR(seed, data);
|
|
||||||
for (int i = 0; i < length; i++) {
|
|
||||||
int[] sub = new int[intV];
|
|
||||||
for (int j = 0; j < 4; j++) {
|
|
||||||
sub[j] = (int) encryptCtr[i * intV + j] & 0xff;
|
|
||||||
}
|
|
||||||
int subI = byte2int(sub, 4);
|
|
||||||
|
|
||||||
Float f = Float.valueOf(Float.valueOf(subI) / Integer.MAX_VALUE);
|
|
||||||
noise.add(f);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public static int byte2int(int[] data, int n) {
|
|
||||||
switch (n) {
|
|
||||||
case 1:
|
|
||||||
return (int) data[0];
|
|
||||||
case 2:
|
|
||||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00);
|
|
||||||
case 3:
|
|
||||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000);
|
|
||||||
case 4:
|
|
||||||
return (int) (data[0] & 0xff) | (data[1] << 8 & 0xff00) | (data[2] << 16 & 0xff0000)
|
|
||||||
| (data[3] << 24 & 0xff000000);
|
|
||||||
default:
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,23 +16,33 @@
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher;
|
package com.mindspore.flclient.cipher;
|
||||||
|
|
||||||
|
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
||||||
|
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
|
|
||||||
import com.mindspore.flclient.Common;
|
import com.mindspore.flclient.Common;
|
||||||
import com.mindspore.flclient.FLClientStatus;
|
import com.mindspore.flclient.FLClientStatus;
|
||||||
import com.mindspore.flclient.FLCommunication;
|
import com.mindspore.flclient.FLCommunication;
|
||||||
import com.mindspore.flclient.FLParameter;
|
import com.mindspore.flclient.FLParameter;
|
||||||
import com.mindspore.flclient.LocalFLParameter;
|
import com.mindspore.flclient.LocalFLParameter;
|
||||||
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
import com.mindspore.flclient.cipher.struct.DecryptShareSecrets;
|
||||||
import mindspore.schema.ClientShare;
|
|
||||||
import mindspore.schema.ResponseCode;
|
|
||||||
|
|
||||||
|
import mindspore.schema.ClientShare;
|
||||||
|
import mindspore.schema.ReconstructSecret;
|
||||||
|
import mindspore.schema.ResponseCode;
|
||||||
|
import mindspore.schema.SendReconstructSecret;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.time.LocalDateTime;
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
import static com.mindspore.flclient.FLParameter.SLEEP_TIME;
|
/**
|
||||||
|
* reconstruct secret request
|
||||||
|
*
|
||||||
|
* @since 2021-8-27
|
||||||
|
*/
|
||||||
public class ReconstructSecretReq {
|
public class ReconstructSecretReq {
|
||||||
private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(ReconstructSecretReq.class.toString());
|
||||||
private FLCommunication flCommunication;
|
private FLCommunication flCommunication;
|
||||||
|
@ -41,36 +51,44 @@ public class ReconstructSecretReq {
|
||||||
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
private LocalFLParameter localFLParameter = LocalFLParameter.getInstance();
|
||||||
private int retCode;
|
private int retCode;
|
||||||
|
|
||||||
public String getNextRequestTime() {
|
/**
|
||||||
return nextRequestTime;
|
* reconstruct secret request
|
||||||
}
|
*/
|
||||||
|
|
||||||
public void setNextRequestTime(String nextRequestTime) {
|
|
||||||
this.nextRequestTime = nextRequestTime;
|
|
||||||
}
|
|
||||||
|
|
||||||
public int getRetCode() {
|
|
||||||
return retCode;
|
|
||||||
}
|
|
||||||
|
|
||||||
public ReconstructSecretReq() {
|
public ReconstructSecretReq() {
|
||||||
flCommunication = FLCommunication.getInstance();
|
flCommunication = FLCommunication.getInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList, List<String> u3ClientList, int iteration) {
|
/**
|
||||||
String url = Common.generateUrl(flParameter.isUseHttps(), flParameter.isUseElb(), flParameter.getIp(), flParameter.getPort(), flParameter.getServerNum());
|
* send secret shards to server
|
||||||
|
*
|
||||||
|
* @param decryptShareSecretsList secret shards list
|
||||||
|
* @param u3ClientList u3 client list
|
||||||
|
* @param iteration iter number
|
||||||
|
* @return request result
|
||||||
|
*/
|
||||||
|
public FLClientStatus sendReconstructSecret(List<DecryptShareSecrets> decryptShareSecretsList,
|
||||||
|
List<String> u3ClientList, int iteration) {
|
||||||
|
String url = Common.generateUrl(flParameter.isUseElb(), flParameter.getServerNum(), flParameter.getDomainName());
|
||||||
FlatBufferBuilder builder = new FlatBufferBuilder();
|
FlatBufferBuilder builder = new FlatBufferBuilder();
|
||||||
int desFlId = builder.createString(localFLParameter.getFlID());
|
int desFlId = builder.createString(localFLParameter.getFlID());
|
||||||
String dateTime = LocalDateTime.now().toString();
|
Date date = new Date();
|
||||||
|
long timestamp = date.getTime();
|
||||||
|
String dateTime = String.valueOf(timestamp);
|
||||||
int time = builder.createString(dateTime);
|
int time = builder.createString(dateTime);
|
||||||
int shareSecretsSize = decryptShareSecretsList.size();
|
int shareSecretsSize = decryptShareSecretsList.size();
|
||||||
if (shareSecretsSize <= 0) {
|
if (shareSecretsSize <= 0) {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please waite."));
|
LOGGER.info(Common.addTag("[PairWiseMask] request failed: the decryptShareSecretsList is null, please " +
|
||||||
|
"waite."));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
} else {
|
} else {
|
||||||
int[] decryptShareList = new int[shareSecretsSize];
|
int[] decryptShareList = new int[shareSecretsSize];
|
||||||
for (int i = 0; i < shareSecretsSize; i++) {
|
for (int i = 0; i < shareSecretsSize; i++) {
|
||||||
DecryptShareSecrets decryptShareSecrets = decryptShareSecretsList.get(i);
|
DecryptShareSecrets decryptShareSecrets = decryptShareSecretsList.get(i);
|
||||||
|
if (decryptShareSecrets.getFlID() == null) {
|
||||||
|
LOGGER.severe(Common.addTag("[PairWiseMask] get remote flID failed!"));
|
||||||
|
return FLClientStatus.FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
String srcFlId = decryptShareSecrets.getFlID();
|
String srcFlId = decryptShareSecrets.getFlID();
|
||||||
byte[] share;
|
byte[] share;
|
||||||
int index;
|
int index;
|
||||||
|
@ -86,31 +104,33 @@ public class ReconstructSecretReq {
|
||||||
int clientShare = ClientShare.createClientShare(builder, fbsSrcFlId, fbsShare, index);
|
int clientShare = ClientShare.createClientShare(builder, fbsSrcFlId, fbsShare, index);
|
||||||
decryptShareList[i] = clientShare;
|
decryptShareList[i] = clientShare;
|
||||||
}
|
}
|
||||||
int reconstructShareSecrets = mindspore.schema.SendReconstructSecret.createReconstructSecretSharesVector(builder, decryptShareList);
|
int reconstructShareSecrets = SendReconstructSecret.createReconstructSecretSharesVector(builder,
|
||||||
int reconstructSecretRoot = mindspore.schema.SendReconstructSecret.createSendReconstructSecret(builder, desFlId, reconstructShareSecrets, iteration, time);
|
decryptShareList);
|
||||||
|
int reconstructSecretRoot = SendReconstructSecret.createSendReconstructSecret(builder, desFlId,
|
||||||
|
reconstructShareSecrets, iteration, time);
|
||||||
builder.finish(reconstructSecretRoot);
|
builder.finish(reconstructSecretRoot);
|
||||||
byte[] msg = builder.sizedByteArray();
|
byte[] msg = builder.sizedByteArray();
|
||||||
try {
|
try {
|
||||||
byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg);
|
byte[] responseData = flCommunication.syncRequest(url + "/reconstructSecrets", msg);
|
||||||
if (Common.isSafeMod(responseData, localFLParameter.getSafeMod())) {
|
if (!Common.isSeverReady(responseData)) {
|
||||||
LOGGER.info(Common.addTag("[sendReconstructSecret] The cluster is in safemode, need wait some time and request again"));
|
LOGGER.info(Common.addTag("[sendReconstructSecret] the server is not ready now, need wait some " +
|
||||||
|
"time and request again"));
|
||||||
Common.sleep(SLEEP_TIME);
|
Common.sleep(SLEEP_TIME);
|
||||||
nextRequestTime = "";
|
nextRequestTime = "";
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
}
|
}
|
||||||
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
ByteBuffer buffer = ByteBuffer.wrap(responseData);
|
||||||
mindspore.schema.ReconstructSecret reconstructSecretRsp = mindspore.schema.ReconstructSecret.getRootAsReconstructSecret(buffer);
|
ReconstructSecret reconstructSecretRsp = ReconstructSecret.getRootAsReconstructSecret(buffer);
|
||||||
FLClientStatus status = judgeSendReconstructSecrets(reconstructSecretRsp);
|
return judgeSendReconstructSecrets(reconstructSecretRsp);
|
||||||
return status;
|
} catch (IOException ex) {
|
||||||
} catch (Exception e) {
|
|
||||||
LOGGER.severe(Common.addTag("[PairWiseMask] un solved error code in reconstruct"));
|
LOGGER.severe(Common.addTag("[PairWiseMask] un solved error code in reconstruct"));
|
||||||
e.printStackTrace();
|
ex.printStackTrace();
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public FLClientStatus judgeSendReconstructSecrets(mindspore.schema.ReconstructSecret bufData) {
|
private FLClientStatus judgeSendReconstructSecrets(ReconstructSecret bufData) {
|
||||||
retCode = bufData.retcode();
|
retCode = bufData.retcode();
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************"));
|
LOGGER.info(Common.addTag("[PairWiseMask] **************the response of SendReconstructSecrets**************"));
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
LOGGER.info(Common.addTag("[PairWiseMask] return code: " + retCode));
|
||||||
|
@ -122,7 +142,8 @@ public class ReconstructSecretReq {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success"));
|
LOGGER.info(Common.addTag("[PairWiseMask] ReconstructSecrets success"));
|
||||||
return FLClientStatus.SUCCESS;
|
return FLClientStatus.SUCCESS;
|
||||||
case (ResponseCode.OutOfTime):
|
case (ResponseCode.OutOfTime):
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request startFLJob again"));
|
LOGGER.info(Common.addTag("[PairWiseMask] SendReconstructSecrets out of time: need wait and request " +
|
||||||
|
"startFLJob again"));
|
||||||
setNextRequestTime(bufData.nextReqTime());
|
setNextRequestTime(bufData.nextReqTime());
|
||||||
return FLClientStatus.RESTART;
|
return FLClientStatus.RESTART;
|
||||||
case (ResponseCode.RequestError):
|
case (ResponseCode.RequestError):
|
||||||
|
@ -130,8 +151,36 @@ public class ReconstructSecretReq {
|
||||||
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets"));
|
LOGGER.info(Common.addTag("[PairWiseMask] catch SucNotMatch or SystemError in SendReconstructSecrets"));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
default:
|
default:
|
||||||
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is invalid: " + retCode));
|
LOGGER.severe(Common.addTag("[PairWiseMask] the return <retCode> from server in ReconstructSecret is " +
|
||||||
|
"invalid: " + retCode));
|
||||||
return FLClientStatus.FAILED;
|
return FLClientStatus.FAILED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get next request time
|
||||||
|
*
|
||||||
|
* @return next request time
|
||||||
|
*/
|
||||||
|
public String getNextRequestTime() {
|
||||||
|
return nextRequestTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set next request time
|
||||||
|
*
|
||||||
|
* @param nextRequestTime next request time
|
||||||
|
*/
|
||||||
|
public void setNextRequestTime(String nextRequestTime) {
|
||||||
|
this.nextRequestTime = nextRequestTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get retCode
|
||||||
|
*
|
||||||
|
* @return retCode
|
||||||
|
*/
|
||||||
|
public int getRetCode() {
|
||||||
|
return retCode;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,115 +22,158 @@ import java.math.BigInteger;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
import java.util.logging.Logger;
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Define functions that for splitting secret and combining secret shards.
|
||||||
|
*
|
||||||
|
* @since 2021-06-30
|
||||||
|
*/
|
||||||
public class ShareSecrets {
|
public class ShareSecrets {
|
||||||
private static final Logger LOGGER = Logger.getLogger(ShareSecrets.class.toString());
|
private static final Logger LOGGER = Logger.getLogger(ShareSecrets.class.toString());
|
||||||
|
|
||||||
public final class SecretShare {
|
private BigInteger prime;
|
||||||
public SecretShare(final int num, final BigInteger share) {
|
private final int minNum;
|
||||||
this.num = num;
|
private final int totalNum;
|
||||||
this.share = share;
|
private final Random random;
|
||||||
}
|
|
||||||
|
|
||||||
public int getNum() {
|
/**
|
||||||
return num;
|
* Defines the constructor of the class ShareSecrets.
|
||||||
|
*
|
||||||
|
* @param minNum minimum number of fragments required to reconstruct a secret.
|
||||||
|
* @param totalNum total clients number.
|
||||||
|
*/
|
||||||
|
public ShareSecrets(final int minNum, final int totalNum) {
|
||||||
|
if (minNum <= 0) {
|
||||||
|
LOGGER.severe(Common.addTag("the argument <k> is not valid: <= 0, it should be > 0"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
|
if (totalNum <= 0) {
|
||||||
public BigInteger getShare() {
|
LOGGER.severe(Common.addTag("the argument <n> is not valid: <= 0, it should be > 0"));
|
||||||
return share;
|
throw new IllegalArgumentException();
|
||||||
}
|
}
|
||||||
|
if (minNum > totalNum) {
|
||||||
@Override
|
LOGGER.severe(Common.addTag("the argument <k, n> is not valid: k > n, it should k <= n"));
|
||||||
public String toString() {
|
throw new IllegalArgumentException();
|
||||||
return "SecretShare [num=" + num + ", share=" + share + "]";
|
|
||||||
}
|
}
|
||||||
|
this.minNum = minNum;
|
||||||
private final int num;
|
this.totalNum = totalNum;
|
||||||
private final BigInteger share;
|
random = Common.getSecureRandom();
|
||||||
}
|
}
|
||||||
|
|
||||||
public ShareSecrets(final int k, final int n) {
|
/**
|
||||||
this.k = k;
|
* Splits a secret into a specified number of secret fragments.
|
||||||
this.n = n;
|
*
|
||||||
|
* @param bytes the secret need to be split.
|
||||||
random = new Random();
|
* @param primeByte teh big prime number used to combine secret fragments.
|
||||||
}
|
* @return the secret fragments.
|
||||||
|
*/
|
||||||
public SecretShare[] split(final byte[] bytes, byte[] primeByte) {
|
public SecretShares[] split(final byte[] bytes, byte[] primeByte) {
|
||||||
|
if (bytes == null || bytes.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("the input argument <bytes> is null"));
|
||||||
|
return new SecretShares[0];
|
||||||
|
}
|
||||||
|
if (primeByte == null || primeByte.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("the input argument <primeByte> is null"));
|
||||||
|
return new SecretShares[0];
|
||||||
|
}
|
||||||
BigInteger secret = BaseUtil.byteArray2BigInteger(bytes);
|
BigInteger secret = BaseUtil.byteArray2BigInteger(bytes);
|
||||||
final int modLength = secret.bitLength() + 1;
|
final int modLength = secret.bitLength() + 1;
|
||||||
prime = BaseUtil.byteArray2BigInteger(primeByte);
|
prime = BaseUtil.byteArray2BigInteger(primeByte);
|
||||||
final BigInteger[] coeff = new BigInteger[k - 1];
|
final BigInteger[] coefficient = new BigInteger[minNum - 1];
|
||||||
|
|
||||||
LOGGER.info(Common.addTag("Prime Number: " + prime));
|
for (int i = 0; i < minNum - 1; i++) {
|
||||||
|
coefficient[i] = randomZp(prime);
|
||||||
for (int i = 0; i < k - 1; i++) {
|
|
||||||
coeff[i] = randomZp(prime);
|
|
||||||
LOGGER.info(Common.addTag("a" + (i + 1) + ": " + coeff[i]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
final SecretShare[] shares = new SecretShare[n];
|
final SecretShares[] shares = new SecretShares[totalNum];
|
||||||
for (int i = 1; i <= n; i++) {
|
for (int i = 1; i <= totalNum; i++) {
|
||||||
BigInteger accum = secret;
|
BigInteger accumulate = secret;
|
||||||
|
|
||||||
for (int j = 1; j < k; j++) {
|
for (int j = 1; j < minNum; j++) {
|
||||||
final BigInteger t1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime);
|
final BigInteger b1 = BigInteger.valueOf(i).modPow(BigInteger.valueOf(j), prime);
|
||||||
final BigInteger t2 = coeff[j - 1].multiply(t1).mod(prime);
|
final BigInteger b2 = coefficient[j - 1].multiply(b1).mod(prime);
|
||||||
|
|
||||||
accum = accum.add(t2).mod(prime);
|
accumulate = accumulate.add(b2).mod(prime);
|
||||||
}
|
}
|
||||||
shares[i - 1] = new SecretShare(i, accum);
|
shares[i - 1] = new SecretShares(i, accumulate);
|
||||||
LOGGER.info(Common.addTag("Share " + shares[i - 1]));
|
|
||||||
}
|
}
|
||||||
return shares;
|
return shares;
|
||||||
}
|
}
|
||||||
|
|
||||||
public BigInteger getPrime() {
|
/**
|
||||||
return prime;
|
* Combine secret fragments.
|
||||||
}
|
*
|
||||||
|
* @param shares the secret fragments.
|
||||||
public BigInteger combine(final SecretShare[] shares, final byte[] primeByte) {
|
* @param primeByte teh big prime number used to combine secret fragments.
|
||||||
|
* @return the secrets combined by secret fragments.
|
||||||
|
*/
|
||||||
|
public BigInteger combine(final SecretShares[] shares, final byte[] primeByte) {
|
||||||
|
if (shares == null || shares.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("the input argument <shares> is null"));
|
||||||
|
return BigInteger.ZERO;
|
||||||
|
}
|
||||||
|
if (primeByte == null || primeByte.length == 0) {
|
||||||
|
LOGGER.severe(Common.addTag("the input argument <primeByte> is null"));
|
||||||
|
return BigInteger.ZERO;
|
||||||
|
}
|
||||||
BigInteger primeNum = BaseUtil.byteArray2BigInteger(primeByte);
|
BigInteger primeNum = BaseUtil.byteArray2BigInteger(primeByte);
|
||||||
BigInteger accum = BigInteger.ZERO;
|
BigInteger accumulate = BigInteger.ZERO;
|
||||||
for (int j = 0; j < k; j++) {
|
for (int j = 0; j < minNum; j++) {
|
||||||
BigInteger num = BigInteger.ONE;
|
BigInteger num = BigInteger.ONE;
|
||||||
BigInteger den = BigInteger.ONE;
|
BigInteger den = BigInteger.ONE;
|
||||||
|
|
||||||
BigInteger tmp;
|
BigInteger tmp;
|
||||||
|
|
||||||
for (int m = 0; m < k; m++) {
|
for (int m = 0; m < minNum; m++) {
|
||||||
if (j != m) {
|
if (j != m) {
|
||||||
num = num.multiply(BigInteger.valueOf(shares[m].getNum())).mod(primeNum);
|
num = num.multiply(BigInteger.valueOf(shares[m].getNumber())).mod(primeNum);
|
||||||
tmp = BigInteger.valueOf(shares[j].getNum()).multiply(BigInteger.valueOf(-1));
|
tmp = BigInteger.valueOf(shares[j].getNumber()).multiply(BigInteger.valueOf(-1));
|
||||||
tmp = BigInteger.valueOf(shares[m].getNum()).add(tmp).mod(primeNum);
|
tmp = BigInteger.valueOf(shares[m].getNumber()).add(tmp).mod(primeNum);
|
||||||
den = den.multiply(tmp).mod(primeNum);
|
den = den.multiply(tmp).mod(primeNum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
final BigInteger value = shares[j].getShare();
|
final BigInteger value = shares[j].getShares();
|
||||||
|
|
||||||
tmp = den.modInverse(primeNum);
|
tmp = den.modInverse(primeNum);
|
||||||
tmp = tmp.multiply(num).mod(primeNum);
|
tmp = tmp.multiply(num).mod(primeNum);
|
||||||
tmp = tmp.multiply(value).mod(primeNum);
|
tmp = tmp.multiply(value).mod(primeNum);
|
||||||
accum = accum.add(tmp).mod(primeNum);
|
accumulate = accumulate.add(tmp).mod(primeNum);
|
||||||
LOGGER.info(Common.addTag("value: " + value + ", tmp: " + tmp + ", accum: " + accum));
|
|
||||||
}
|
}
|
||||||
LOGGER.info(Common.addTag("The secret is: " + accum));
|
return accumulate;
|
||||||
return accum;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private BigInteger randomZp(final BigInteger p) {
|
private BigInteger randomZp(final BigInteger num) {
|
||||||
while (true) {
|
while (true) {
|
||||||
final BigInteger r = new BigInteger(p.bitLength(), random);
|
final BigInteger rand = new BigInteger(num.bitLength(), random);
|
||||||
if (r.compareTo(BigInteger.ZERO) > 0 && r.compareTo(p) < 0) {
|
if (rand.compareTo(BigInteger.ZERO) > 0 && rand.compareTo(num) < 0) {
|
||||||
return r;
|
return rand;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private BigInteger prime;
|
/**
|
||||||
private final int k;
|
* Define the structure for store secret fragments.
|
||||||
private final int n;
|
*/
|
||||||
private final Random random;
|
public final class SecretShares {
|
||||||
private final int SECRET_MAX_LEN = 32;
|
private final int number;
|
||||||
|
private final BigInteger share;
|
||||||
|
|
||||||
|
public SecretShares(final int number, final BigInteger share) {
|
||||||
|
this.number = number;
|
||||||
|
this.share = share;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumber() {
|
||||||
|
return number;
|
||||||
|
}
|
||||||
|
|
||||||
|
public BigInteger getShares() {
|
||||||
|
return share;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "SecretShares [number=" + number + ", share=" + share + "]";
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,33 +16,114 @@
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher.struct;
|
package com.mindspore.flclient.cipher.struct;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.Common;
|
||||||
|
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* public key class of secure aggregation
|
||||||
|
*
|
||||||
|
* @since 2021-8-27
|
||||||
|
*/
|
||||||
public class ClientPublicKey {
|
public class ClientPublicKey {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(ClientPublicKey.class.toString());
|
||||||
private String flID;
|
private String flID;
|
||||||
private NewArray<byte[]> cPK;
|
private NewArray<byte[]> cPK;
|
||||||
private NewArray<byte[]> sPk;
|
private NewArray<byte[]> sPk;
|
||||||
|
private NewArray<byte[]> pwIv;
|
||||||
|
private NewArray<byte[]> pwSalt;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get client's flID
|
||||||
|
*
|
||||||
|
* @return flID of this client
|
||||||
|
*/
|
||||||
public String getFlID() {
|
public String getFlID() {
|
||||||
|
if (flID == null || flID.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[ClientPublicKey] the parameter of <flID> is null, please set it before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
return flID;
|
return flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set client's flID
|
||||||
|
*
|
||||||
|
* @param flID hash value used for identify client
|
||||||
|
*/
|
||||||
public void setFlID(String flID) {
|
public void setFlID(String flID) {
|
||||||
this.flID = flID;
|
this.flID = flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get CPK of secure aggregation
|
||||||
|
*
|
||||||
|
* @return CPK of secure aggregation
|
||||||
|
*/
|
||||||
public NewArray<byte[]> getCPK() {
|
public NewArray<byte[]> getCPK() {
|
||||||
return cPK;
|
return cPK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set CPK of secure aggregation
|
||||||
|
*
|
||||||
|
* @param cPK public key used for encryption
|
||||||
|
*/
|
||||||
public void setCPK(NewArray<byte[]> cPK) {
|
public void setCPK(NewArray<byte[]> cPK) {
|
||||||
this.cPK = cPK;
|
this.cPK = cPK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get SPK of secure aggregation
|
||||||
|
*
|
||||||
|
* @return SPK of secure aggregation
|
||||||
|
*/
|
||||||
public NewArray<byte[]> getSPK() {
|
public NewArray<byte[]> getSPK() {
|
||||||
return sPk;
|
return sPk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set SPK of secure aggregation
|
||||||
|
*
|
||||||
|
* @param sPk public key used for encryption
|
||||||
|
*/
|
||||||
public void setSPK(NewArray<byte[]> sPk) {
|
public void setSPK(NewArray<byte[]> sPk) {
|
||||||
this.sPk = sPk;
|
this.sPk = sPk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get the IV value used for pairwise encrypt
|
||||||
|
*
|
||||||
|
* @return the IV value used for pairwise encrypt
|
||||||
|
*/
|
||||||
|
public NewArray<byte[]> getPwIv() {
|
||||||
|
return pwIv;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set the IV value used for pairwise encrypt
|
||||||
|
*
|
||||||
|
* @param pwIv IV value used for pairwise encrypt
|
||||||
|
*/
|
||||||
|
public void setPwIv(NewArray<byte[]> pwIv) {
|
||||||
|
this.pwIv = pwIv;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get salt value for secure aggregation
|
||||||
|
*
|
||||||
|
* @return salt value for secure aggregation
|
||||||
|
*/
|
||||||
|
public NewArray<byte[]> getPwSalt() {
|
||||||
|
return pwSalt;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set salt value for secure aggregation
|
||||||
|
*
|
||||||
|
* @param pwSalt salt value for secure aggregation
|
||||||
|
*/
|
||||||
|
public void setPwSalt(NewArray<byte[]> pwSalt) {
|
||||||
|
this.pwSalt = pwSalt;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,49 +16,114 @@
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher.struct;
|
package com.mindspore.flclient.cipher.struct;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.Common;
|
||||||
|
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* class used for set and get decryption shards
|
||||||
|
*
|
||||||
|
* @since 2021-8-27
|
||||||
|
*/
|
||||||
public class DecryptShareSecrets {
|
public class DecryptShareSecrets {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(DecryptShareSecrets.class.toString());
|
||||||
private String flID;
|
private String flID;
|
||||||
private NewArray<byte[]> sSkVu;
|
private NewArray<byte[]> sSkVu;
|
||||||
private NewArray<byte[]> bVu;
|
private NewArray<byte[]> bVu;
|
||||||
private int sIndex;
|
private int sIndex;
|
||||||
private int indexB;
|
private int indexB;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get flID of client
|
||||||
|
*
|
||||||
|
* @return flID of this client
|
||||||
|
*/
|
||||||
public String getFlID() {
|
public String getFlID() {
|
||||||
|
if (flID == null || flID.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
|
||||||
|
"use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
return flID;
|
return flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set flID for this client
|
||||||
|
*
|
||||||
|
* @param flID hash value used for identify client
|
||||||
|
*/
|
||||||
public void setFlID(String flID) {
|
public void setFlID(String flID) {
|
||||||
this.flID = flID;
|
this.flID = flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get secret key shards
|
||||||
|
*
|
||||||
|
* @return secret key shards
|
||||||
|
*/
|
||||||
public NewArray<byte[]> getSSkVu() {
|
public NewArray<byte[]> getSSkVu() {
|
||||||
return sSkVu;
|
return sSkVu;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set secret key shards
|
||||||
|
*
|
||||||
|
* @param sSkVu secret key shards
|
||||||
|
*/
|
||||||
public void setSSkVu(NewArray<byte[]> sSkVu) {
|
public void setSSkVu(NewArray<byte[]> sSkVu) {
|
||||||
this.sSkVu = sSkVu;
|
this.sSkVu = sSkVu;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get bu shards
|
||||||
|
*
|
||||||
|
* @return bu shards
|
||||||
|
*/
|
||||||
public NewArray<byte[]> getBVu() {
|
public NewArray<byte[]> getBVu() {
|
||||||
return bVu;
|
return bVu;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set bu shards
|
||||||
|
*
|
||||||
|
* @param bVu bu shards used for secure aggregation
|
||||||
|
*/
|
||||||
public void setBVu(NewArray<byte[]> bVu) {
|
public void setBVu(NewArray<byte[]> bVu) {
|
||||||
this.bVu = bVu;
|
this.bVu = bVu;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get index of secret shards
|
||||||
|
*
|
||||||
|
* @return index of secret shards
|
||||||
|
*/
|
||||||
public int getSIndex() {
|
public int getSIndex() {
|
||||||
return sIndex;
|
return sIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set index of secret shards
|
||||||
|
*
|
||||||
|
* @param sIndex index of secret shards
|
||||||
|
*/
|
||||||
public void setSIndex(int sIndex) {
|
public void setSIndex(int sIndex) {
|
||||||
this.sIndex = sIndex;
|
this.sIndex = sIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get index of bu shards
|
||||||
|
*
|
||||||
|
* @return index of bu shards
|
||||||
|
*/
|
||||||
public int getIndexB() {
|
public int getIndexB() {
|
||||||
return indexB;
|
return indexB;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set index of bu shards
|
||||||
|
*
|
||||||
|
* @param indexB index of bu shards
|
||||||
|
*/
|
||||||
public void setIndexB(int indexB) {
|
public void setIndexB(int indexB) {
|
||||||
this.indexB = indexB;
|
this.indexB = indexB;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,22 +16,57 @@
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher.struct;
|
package com.mindspore.flclient.cipher.struct;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.Common;
|
||||||
|
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* class used for encrypt shares of secret
|
||||||
|
*
|
||||||
|
* @since 2021-8-27
|
||||||
|
*/
|
||||||
public class EncryptShare {
|
public class EncryptShare {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(DecryptShareSecrets.class.toString());
|
||||||
private String flID;
|
private String flID;
|
||||||
private NewArray<byte[]> share;
|
private NewArray<byte[]> share;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get client's flID
|
||||||
|
*
|
||||||
|
* @return flID of this client
|
||||||
|
*/
|
||||||
public String getFlID() {
|
public String getFlID() {
|
||||||
|
if (flID == null || flID.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[DecryptShareSecrets] the parameter of <flID> is null, please set it before " +
|
||||||
|
"use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
return flID;
|
return flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set client's flID
|
||||||
|
*
|
||||||
|
* @param flID hash value used for identify client
|
||||||
|
*/
|
||||||
public void setFlID(String flID) {
|
public void setFlID(String flID) {
|
||||||
this.flID = flID;
|
this.flID = flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get secret share
|
||||||
|
*
|
||||||
|
* @return secret share
|
||||||
|
*/
|
||||||
public NewArray<byte[]> getShare() {
|
public NewArray<byte[]> getShare() {
|
||||||
return share;
|
return share;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set secret share
|
||||||
|
*
|
||||||
|
* @param share secret share
|
||||||
|
*/
|
||||||
public void setShare(NewArray<byte[]> share) {
|
public void setShare(NewArray<byte[]> share) {
|
||||||
this.share = share;
|
this.share = share;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,24 +16,50 @@
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher.struct;
|
package com.mindspore.flclient.cipher.struct;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* class used define new array type
|
||||||
|
*
|
||||||
|
* @param <T> an array
|
||||||
|
*
|
||||||
|
* @since 2021-8-27
|
||||||
|
*/
|
||||||
public class NewArray<T> {
|
public class NewArray<T> {
|
||||||
private int size;
|
private int size;
|
||||||
private T array;
|
private T array;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get array size
|
||||||
|
*
|
||||||
|
* @return array size
|
||||||
|
*/
|
||||||
public int getSize() {
|
public int getSize() {
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set array size
|
||||||
|
*
|
||||||
|
* @param size array size
|
||||||
|
*/
|
||||||
public void setSize(int size) {
|
public void setSize(int size) {
|
||||||
this.size = size;
|
this.size = size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get array
|
||||||
|
*
|
||||||
|
* @return an array
|
||||||
|
*/
|
||||||
public T getArray() {
|
public T getArray() {
|
||||||
return array;
|
return array;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set array
|
||||||
|
*
|
||||||
|
* @param array input
|
||||||
|
*/
|
||||||
public void setArray(T array) {
|
public void setArray(T array) {
|
||||||
this.array = array;
|
this.array = array;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/*
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -16,31 +16,75 @@
|
||||||
|
|
||||||
package com.mindspore.flclient.cipher.struct;
|
package com.mindspore.flclient.cipher.struct;
|
||||||
|
|
||||||
|
import com.mindspore.flclient.Common;
|
||||||
|
|
||||||
|
import java.util.logging.Logger;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* share secret class
|
||||||
|
*
|
||||||
|
* @since 2021-8-27
|
||||||
|
*/
|
||||||
public class ShareSecret {
|
public class ShareSecret {
|
||||||
|
private static final Logger LOGGER = Logger.getLogger(ShareSecret.class.toString());
|
||||||
private String flID;
|
private String flID;
|
||||||
private NewArray<byte[]> share;
|
private NewArray<byte[]> share;
|
||||||
private int index;
|
private int index;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get client's flID
|
||||||
|
*
|
||||||
|
* @return flID of this client
|
||||||
|
*/
|
||||||
public String getFlID() {
|
public String getFlID() {
|
||||||
|
if (flID == null || flID.isEmpty()) {
|
||||||
|
LOGGER.severe(Common.addTag("[ShareSecret] the parameter of <flID> is null, please set it before use"));
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
return flID;
|
return flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set flID for this client
|
||||||
|
*
|
||||||
|
* @param flID hash value used for identify client
|
||||||
|
*/
|
||||||
public void setFlID(String flID) {
|
public void setFlID(String flID) {
|
||||||
this.flID = flID;
|
this.flID = flID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get secret share
|
||||||
|
*
|
||||||
|
* @return secret share
|
||||||
|
*/
|
||||||
public NewArray<byte[]> getShare() {
|
public NewArray<byte[]> getShare() {
|
||||||
return share;
|
return share;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set secret share
|
||||||
|
*
|
||||||
|
* @param share secret shares
|
||||||
|
*/
|
||||||
public void setShare(NewArray<byte[]> share) {
|
public void setShare(NewArray<byte[]> share) {
|
||||||
this.share = share;
|
this.share = share;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get secret index
|
||||||
|
*
|
||||||
|
* @return secret index
|
||||||
|
*/
|
||||||
public int getIndex() {
|
public int getIndex() {
|
||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set secret index
|
||||||
|
*
|
||||||
|
* @param index secret index
|
||||||
|
*/
|
||||||
public void setIndex(int index) {
|
public void setIndex(int index) {
|
||||||
this.index = index;
|
this.index = index;
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,8 @@ table ClientPublicKeys {
|
||||||
fl_id:string;
|
fl_id:string;
|
||||||
c_pk:[ubyte];
|
c_pk:[ubyte];
|
||||||
s_pk: [ubyte];
|
s_pk: [ubyte];
|
||||||
|
pw_iv: [ubyte];
|
||||||
|
pw_salt: [ubyte];
|
||||||
}
|
}
|
||||||
|
|
||||||
table ClientShare {
|
table ClientShare {
|
||||||
|
@ -45,6 +47,9 @@ table RequestExchangeKeys{
|
||||||
s_pk:[ubyte];
|
s_pk:[ubyte];
|
||||||
iteration:int;
|
iteration:int;
|
||||||
timestamp:string;
|
timestamp:string;
|
||||||
|
ind_iv:[ubyte];
|
||||||
|
pw_iv:[ubyte];
|
||||||
|
pw_salt:[ubyte];
|
||||||
}
|
}
|
||||||
|
|
||||||
table ResponseExchangeKeys{
|
table ResponseExchangeKeys{
|
||||||
|
|
Loading…
Reference in New Issue