forked from mindspore-Ecosystem/mindspore
Add exchangeKeys ops and getKeys ops for STABLE_PW_ENCRYPT
fix review syggestions
This commit is contained in:
parent
b2dfb595e7
commit
f972de90d0
|
@ -60,6 +60,8 @@ if(NOT ENABLE_CPU OR WIN32)
|
|||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/start_fl_job_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/update_model_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/push_metrics_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/get_keys_kernel.cc")
|
||||
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/exchange_keys_kernel.cc")
|
||||
endif()
|
||||
|
||||
if(ENABLE_SECURITY)
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/fl/exchange_keys_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr int iv_vec_len = 16;
|
||||
constexpr int salt_len = 32;
|
||||
|
||||
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &) {
|
||||
MS_LOG(INFO) << "Launching client ExchangeKeysKernel";
|
||||
if (!BuildExchangeKeysReq(fbb_)) {
|
||||
MS_LOG(EXCEPTION) << "Building request for ExchangeKeys failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::shared_ptr<std::vector<unsigned char>> exchange_keys_rsp_msg = nullptr;
|
||||
if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(), fbb_->GetSize(),
|
||||
ps::core::TcpUserCommand::kExchangeKeys,
|
||||
&exchange_keys_rsp_msg)) {
|
||||
MS_LOG(EXCEPTION) << "Sending request for ExchangeKeys to server " << target_server_rank_ << " failed.";
|
||||
return false;
|
||||
}
|
||||
if (exchange_keys_rsp_msg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Received message pointer is nullptr.";
|
||||
return false;
|
||||
}
|
||||
flatbuffers::Verifier verifier(exchange_keys_rsp_msg->data(), exchange_keys_rsp_msg->size());
|
||||
if (!verifier.VerifyBuffer<schema::ResponseExchangeKeys>()) {
|
||||
MS_LOG(EXCEPTION) << "The schema of ResponseExchangeKeys is invalid.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const schema::ResponseExchangeKeys *exchange_keys_rsp =
|
||||
flatbuffers::GetRoot<schema::ResponseExchangeKeys>(exchange_keys_rsp_msg->data());
|
||||
MS_EXCEPTION_IF_NULL(exchange_keys_rsp);
|
||||
auto response_code = exchange_keys_rsp->retcode();
|
||||
if ((response_code != schema::ResponseCode_SUCCEED) && (response_code != schema::ResponseCode_OutOfTime)) {
|
||||
MS_LOG(EXCEPTION) << "Launching exchange keys job for worker failed. Reason: " << exchange_keys_rsp->reason();
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Exchange keys successfully.";
|
||||
return true;
|
||||
}
|
||||
|
||||
void ExchangeKeysKernel::Init(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
fl_id_ = fl::worker::FLWorker::GetInstance().fl_id();
|
||||
server_num_ = fl::worker::FLWorker::GetInstance().server_num();
|
||||
rank_id_ = fl::worker::FLWorker::GetInstance().rank_id();
|
||||
if (rank_id_ == UINT32_MAX) {
|
||||
MS_LOG(EXCEPTION) << "Federated worker is not initialized yet.";
|
||||
return;
|
||||
}
|
||||
|
||||
if (server_num_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Server number should be larger than 0, but got: " << server_num_;
|
||||
return;
|
||||
}
|
||||
target_server_rank_ = rank_id_ % server_num_;
|
||||
|
||||
MS_LOG(INFO) << "Initializing ExchangeKeys kernel"
|
||||
<< ", fl_id: " << fl_id_ << ". Request will be sent to server " << target_server_rank_;
|
||||
|
||||
fbb_ = std::make_shared<fl::FBBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(fbb_);
|
||||
input_size_list_.push_back(sizeof(int));
|
||||
output_size_list_.push_back(sizeof(float));
|
||||
MS_LOG(INFO) << "Initialize ExchangeKeys kernel successfully.";
|
||||
}
|
||||
|
||||
void ExchangeKeysKernel::InitKernel(const CNodePtr &kernel_node) { return; }
|
||||
|
||||
bool ExchangeKeysKernel::BuildExchangeKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb) {
|
||||
MS_EXCEPTION_IF_NULL(fbb);
|
||||
// generate initialization vector value used for generate pairwise noise
|
||||
std::vector<uint8_t> pw_iv_(iv_vec_len);
|
||||
std::vector<uint8_t> pw_salt_(salt_len);
|
||||
auto ret = RAND_bytes(pw_iv_.data(), iv_vec_len);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "RAND_bytes error, failed to init pw_iv_.";
|
||||
return false;
|
||||
}
|
||||
// generate salt value used for generate pairwise noise
|
||||
ret = RAND_bytes(pw_salt_.data(), salt_len);
|
||||
if (ret != 1) {
|
||||
MS_LOG(ERROR) << "RAND_bytes error, failed to init pw_salt_.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// save pw_salt and pw_iv at local
|
||||
fl::worker::FLWorker::GetInstance().set_pw_salt(pw_salt_);
|
||||
fl::worker::FLWorker::GetInstance().set_pw_iv(pw_iv_);
|
||||
|
||||
// get public key bytes
|
||||
std::vector<uint8_t> pubkey_bytes = GetPubicKeyBytes();
|
||||
if (pubkey_bytes.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "Get public key failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// build data which will be send to server
|
||||
int iter = fl::worker::FLWorker::GetInstance().fl_iteration_num();
|
||||
auto fbs_fl_id = fbb->CreateString(fl_id_);
|
||||
auto fbs_public_key = fbb->CreateVector(pubkey_bytes.data(), pubkey_bytes.size());
|
||||
auto fbs_pw_iv = fbb->CreateVector(pw_iv_.data(), iv_vec_len);
|
||||
auto fbs_pw_salt = fbb->CreateVector(pw_salt_.data(), salt_len);
|
||||
schema::RequestExchangeKeysBuilder req_exchange_key_builder(*(fbb.get()));
|
||||
req_exchange_key_builder.add_fl_id(fbs_fl_id);
|
||||
req_exchange_key_builder.add_s_pk(fbs_public_key);
|
||||
req_exchange_key_builder.add_iteration(iter);
|
||||
req_exchange_key_builder.add_pw_iv(fbs_pw_iv);
|
||||
req_exchange_key_builder.add_pw_salt(fbs_pw_salt);
|
||||
auto req_fl_job = req_exchange_key_builder.Finish();
|
||||
fbb->Finish(req_fl_job);
|
||||
MS_LOG(INFO) << "BuildExchangeKeysReq successfully.";
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<uint8_t> ExchangeKeysKernel::GetPubicKeyBytes() {
|
||||
// generate private key of secret
|
||||
armour::PrivateKey *sPriKeyPtr = armour::KeyAgreement::GeneratePrivKey();
|
||||
fl::worker::FLWorker::GetInstance().set_secret_pk(sPriKeyPtr);
|
||||
|
||||
// get public bytes length
|
||||
size_t pubLen;
|
||||
uint8_t *secret_pubkey_ptr = NULL;
|
||||
auto ret = sPriKeyPtr->GetPublicBytes(&pubLen, secret_pubkey_ptr);
|
||||
if (ret != 0 || pubLen == 0) {
|
||||
MS_LOG(ERROR) << "GetPublicBytes error, failed to get public_key bytes length.";
|
||||
return {};
|
||||
}
|
||||
// pubLen has been updated, now get public_key bytes
|
||||
secret_pubkey_ptr = reinterpret_cast<uint8_t *>(malloc(pubLen));
|
||||
ret = sPriKeyPtr->GetPublicBytes(&pubLen, secret_pubkey_ptr);
|
||||
if (ret != 0) {
|
||||
free(secret_pubkey_ptr);
|
||||
MS_LOG(ERROR) << "GetPublicBytes error, failed to get public_key bytes.";
|
||||
return {};
|
||||
}
|
||||
|
||||
// transform key buffer to uint8_t vector
|
||||
std::vector<uint8_t> pubkey_bytes(pubLen);
|
||||
for (int i = 0; i < SizeToInt(pubLen); i++) {
|
||||
pubkey_bytes[i] = secret_pubkey_ptr[i];
|
||||
}
|
||||
free(secret_pubkey_ptr);
|
||||
return pubkey_bytes;
|
||||
}
|
||||
|
||||
MS_REG_CPU_KERNEL(ExchangeKeys, KernelAttr().AddOutputAttr(kNumberTypeFloat32), ExchangeKeysKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_EXCHANGE_KEYS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_EXCHANGE_KEYS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "fl/worker/fl_worker.h"
|
||||
#include "fl/armour/secure_protocol/key_agreement.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class ExchangeKeysKernel : public CPUKernel {
|
||||
public:
|
||||
ExchangeKeysKernel() = default;
|
||||
~ExchangeKeysKernel() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &) override;
|
||||
|
||||
void Init(const CNodePtr &kernel_node) override;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
private:
|
||||
bool BuildExchangeKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb);
|
||||
std::vector<uint8_t> GetPubicKeyBytes();
|
||||
|
||||
uint32_t rank_id_;
|
||||
uint32_t server_num_;
|
||||
uint32_t target_server_rank_;
|
||||
std::string fl_id_;
|
||||
std::shared_ptr<fl::FBBuilder> fbb_;
|
||||
armour::PrivateKey *secret_prikey_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_EXCHANGE_KEYS_H_
|
|
@ -0,0 +1,141 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/fl/get_keys_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &) {
|
||||
MS_LOG(INFO) << "Launching client GetKeysKernel";
|
||||
BuildGetKeysReq(fbb_);
|
||||
|
||||
std::shared_ptr<std::vector<unsigned char>> get_keys_rsp_msg = nullptr;
|
||||
if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(), fbb_->GetSize(),
|
||||
ps::core::TcpUserCommand::kGetKeys, &get_keys_rsp_msg)) {
|
||||
MS_LOG(EXCEPTION) << "Sending request for GetKeys to server " << target_server_rank_ << " failed.";
|
||||
return false;
|
||||
}
|
||||
if (get_keys_rsp_msg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Received message pointer is nullptr.";
|
||||
return false;
|
||||
}
|
||||
flatbuffers::Verifier verifier(get_keys_rsp_msg->data(), get_keys_rsp_msg->size());
|
||||
if (!verifier.VerifyBuffer<schema::ReturnExchangeKeys>()) {
|
||||
MS_LOG(EXCEPTION) << "The schema of ResponseGetKeys is invalid.";
|
||||
return false;
|
||||
}
|
||||
|
||||
const schema::ReturnExchangeKeys *get_keys_rsp =
|
||||
flatbuffers::GetRoot<schema::ReturnExchangeKeys>(get_keys_rsp_msg->data());
|
||||
MS_EXCEPTION_IF_NULL(get_keys_rsp);
|
||||
auto response_code = get_keys_rsp->retcode();
|
||||
if ((response_code != schema::ResponseCode_SUCCEED) && (response_code != schema::ResponseCode_OutOfTime)) {
|
||||
MS_LOG(EXCEPTION) << "Launching get keys job for worker failed. response_code: " << response_code;
|
||||
}
|
||||
|
||||
bool save_keys_succeed = SavePublicKeyList(get_keys_rsp->remote_publickeys());
|
||||
if (!save_keys_succeed) {
|
||||
MS_LOG(EXCEPTION) << "Save received remote keys failed.";
|
||||
return false;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Get keys successfully.";
|
||||
return true;
|
||||
}
|
||||
|
||||
void GetKeysKernel::Init(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
fl_id_ = fl::worker::FLWorker::GetInstance().fl_id();
|
||||
server_num_ = fl::worker::FLWorker::GetInstance().server_num();
|
||||
rank_id_ = fl::worker::FLWorker::GetInstance().rank_id();
|
||||
if (rank_id_ == UINT32_MAX) {
|
||||
MS_LOG(EXCEPTION) << "Federated worker is not initialized yet.";
|
||||
return;
|
||||
}
|
||||
if (server_num_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Server number should be larger than 0, but got: " << server_num_;
|
||||
return;
|
||||
}
|
||||
target_server_rank_ = rank_id_ % server_num_;
|
||||
|
||||
MS_LOG(INFO) << "Initializing GetKeys kernel"
|
||||
<< ", fl_id: " << fl_id_ << ". Request will be sent to server " << target_server_rank_;
|
||||
|
||||
fbb_ = std::make_shared<fl::FBBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(fbb_);
|
||||
input_size_list_.push_back(sizeof(int));
|
||||
output_size_list_.push_back(sizeof(float));
|
||||
MS_LOG(INFO) << "Initialize GetKeys kernel successfully.";
|
||||
}
|
||||
|
||||
void GetKeysKernel::InitKernel(const CNodePtr &kernel_node) { return; }
|
||||
|
||||
void GetKeysKernel::BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb) {
|
||||
MS_EXCEPTION_IF_NULL(fbb);
|
||||
int iter = fl::worker::FLWorker::GetInstance().fl_iteration_num();
|
||||
auto fbs_fl_id = fbb->CreateString(fl_id_);
|
||||
schema::GetExchangeKeysBuilder get_keys_builder(*(fbb.get()));
|
||||
get_keys_builder.add_fl_id(fbs_fl_id);
|
||||
get_keys_builder.add_iteration(iter);
|
||||
auto req_fl_job = get_keys_builder.Finish();
|
||||
fbb->Finish(req_fl_job);
|
||||
MS_LOG(INFO) << "BuildGetKeysReq successfully.";
|
||||
}
|
||||
|
||||
bool GetKeysKernel::SavePublicKeyList(auto remote_public_key) {
|
||||
if (remote_public_key == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Input remote_pubic_key is nullptr.";
|
||||
}
|
||||
|
||||
int client_num = remote_public_key->size();
|
||||
if (client_num <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Received client keys length is <= 0, please check it!";
|
||||
return false;
|
||||
}
|
||||
|
||||
// save client keys list
|
||||
std::vector<EncryptPublicKeys> saved_remote_public_keys;
|
||||
for (auto iter = remote_public_key->begin(); iter != remote_public_key->end(); ++iter) {
|
||||
std::string fl_id = iter->fl_id()->str();
|
||||
auto fbs_spk = iter->s_pk();
|
||||
auto fbs_pw_iv = iter->pw_iv();
|
||||
auto fbs_pw_salt = iter->pw_salt();
|
||||
if (fbs_spk == nullptr || fbs_pw_iv == nullptr || fbs_pw_salt == nullptr) {
|
||||
MS_LOG(WARNING) << "public key, pw_iv or pw_salt in remote_publickeys is nullptr.";
|
||||
} else {
|
||||
std::vector<uint8_t> spk_vector;
|
||||
std::vector<uint8_t> pw_iv_vector;
|
||||
std::vector<uint8_t> pw_salt_vector;
|
||||
spk_vector.assign(fbs_spk->begin(), fbs_spk->end());
|
||||
pw_iv_vector.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
|
||||
pw_salt_vector.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
|
||||
EncryptPublicKeys public_keys_i;
|
||||
public_keys_i.flID = fl_id;
|
||||
public_keys_i.publicKey = spk_vector;
|
||||
public_keys_i.pwIV = pw_iv_vector;
|
||||
public_keys_i.pwSalt = pw_salt_vector;
|
||||
saved_remote_public_keys.push_back(public_keys_i);
|
||||
MS_LOG(INFO) << "Add public keys of client:" << fl_id << " successfully.";
|
||||
}
|
||||
}
|
||||
fl::worker::FLWorker::GetInstance().set_public_keys_list(saved_remote_public_keys);
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_REG_CPU_KERNEL(GetKeys, KernelAttr().AddOutputAttr(kNumberTypeFloat32), GetKeysKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_KEYS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_KEYS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "fl/worker/fl_worker.h"
|
||||
#include "fl/armour/secure_protocol/key_agreement.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class GetKeysKernel : public CPUKernel {
|
||||
public:
|
||||
GetKeysKernel() = default;
|
||||
~GetKeysKernel() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &) override;
|
||||
|
||||
void Init(const CNodePtr &kernel_node) override;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
private:
|
||||
void BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb);
|
||||
bool SavePublicKeyList(auto remote_public_key);
|
||||
|
||||
uint32_t rank_id_;
|
||||
uint32_t server_num_;
|
||||
uint32_t target_server_rank_;
|
||||
std::string fl_id_;
|
||||
std::shared_ptr<fl::FBBuilder> fbb_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_KEYS_H_
|
|
@ -81,6 +81,11 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
|||
}
|
||||
MS_LOG(INFO) << " CipherInit::Init Success";
|
||||
}
|
||||
|
||||
if (param.encrypt_type == mindspore::ps::kStablePWEncryptType) {
|
||||
cipher_meta_storage_.RegisterStablePWClass();
|
||||
MS_LOG(INFO) << "Register metadata for StablePWEncrypt is finished.";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -34,7 +34,12 @@ bool CipherKeys::GetKeys(const size_t cur_iterator, const std::string &next_req_
|
|||
}
|
||||
// get clientlist from memory server.
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
|
||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||
if (encrypt_type == ps::kPWEncryptType) {
|
||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||
} else {
|
||||
cipher_init_->cipher_meta_storage_.GetStableClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||
}
|
||||
|
||||
size_t cur_exchange_clients_num = client_public_keys.size();
|
||||
std::string fl_id = get_exchange_keys_req->fl_id()->str();
|
||||
|
@ -97,7 +102,13 @@ bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next
|
|||
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
|
||||
std::vector<std::string> client_list;
|
||||
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
|
||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||
|
||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||
if (encrypt_type == ps::kPWEncryptType) {
|
||||
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||
} else {
|
||||
cipher_init_->cipher_meta_storage_.GetStableClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
|
||||
}
|
||||
|
||||
// step2: process new item data. and update new item data to memory server.
|
||||
size_t cur_clients_num = client_list.size();
|
||||
|
@ -119,8 +130,15 @@ bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next
|
|||
return false;
|
||||
}
|
||||
|
||||
bool retcode_key =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
|
||||
bool retcode_key;
|
||||
if (encrypt_type == ps::kPWEncryptType) {
|
||||
retcode_key =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
|
||||
} else {
|
||||
retcode_key =
|
||||
cipher_init_->cipher_meta_storage_.UpdateStableClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
|
||||
}
|
||||
|
||||
bool retcode_client =
|
||||
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
|
||||
if (retcode_key && retcode_client) {
|
||||
|
@ -134,7 +152,7 @@ bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next
|
|||
cur_iterator);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace armour
|
||||
|
||||
void CipherKeys::BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb,
|
||||
const schema::ResponseCode retcode, const std::string &reason,
|
||||
|
@ -164,6 +182,7 @@ void CipherKeys::BuildGetKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &f
|
|||
fbb->Finish(rsp_get_keys);
|
||||
return;
|
||||
}
|
||||
|
||||
const fl::PBMetadata &clients_keys_pb_out =
|
||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientsKeys);
|
||||
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
||||
|
@ -172,17 +191,27 @@ void CipherKeys::BuildGetKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &f
|
|||
std::string fl_id = iter->first;
|
||||
fl::KeysPb keys_pb = iter->second;
|
||||
auto fbs_fl_id = fbb->CreateString(fl_id);
|
||||
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
||||
auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size());
|
||||
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
|
||||
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
|
||||
auto fbs_pw_iv = fbb->CreateVector(pw_iv.data(), pw_iv.size());
|
||||
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
|
||||
auto fbs_pw_salt = fbb->CreateVector(pw_salt.data(), pw_salt.size());
|
||||
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
|
||||
public_keys_list.push_back(cur_public_key);
|
||||
|
||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||
if (encrypt_type == ps::kPWEncryptType) {
|
||||
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
|
||||
auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size());
|
||||
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
|
||||
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
|
||||
public_keys_list.push_back(cur_public_key);
|
||||
} else {
|
||||
std::vector<uint8_t> spk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
|
||||
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, 0, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
|
||||
public_keys_list.push_back(cur_public_key);
|
||||
}
|
||||
}
|
||||
|
||||
auto remote_publickeys = fbb->CreateVector(public_keys_list);
|
||||
auto fbs_next_req_time = fbb->CreateString(next_req_time);
|
||||
schema::ReturnExchangeKeysBuilder rsp_builder(*(fbb.get()));
|
||||
|
|
|
@ -81,6 +81,26 @@ void CipherMetaStorage::GetClientKeysFromServer(
|
|||
}
|
||||
}
|
||||
|
||||
void CipherMetaStorage::GetStableClientKeysFromServer(
|
||||
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) {
|
||||
if (clients_keys_list == nullptr) {
|
||||
MS_LOG(ERROR) << "Input clients_keys_list is nullptr";
|
||||
return;
|
||||
}
|
||||
const fl::PBMetadata &clients_keys_pb_out =
|
||||
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
|
||||
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
|
||||
|
||||
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
|
||||
std::string fl_id = iter->first;
|
||||
fl::KeysPb keys_pb = iter->second;
|
||||
std::vector<uint8_t> spk(keys_pb.key(0).begin(), keys_pb.key(0).end());
|
||||
std::vector<std::vector<uint8_t>> cur_keys;
|
||||
cur_keys.push_back(spk);
|
||||
(void)clients_keys_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
|
||||
}
|
||||
}
|
||||
|
||||
void CipherMetaStorage::GetClientIVsFromServer(
|
||||
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) {
|
||||
if (clients_ivs_list == nullptr) {
|
||||
|
@ -171,26 +191,6 @@ void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &
|
|||
(void)sleep(time);
|
||||
}
|
||||
|
||||
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
||||
const std::vector<std::vector<uint8_t>> &cur_public_key) {
|
||||
const size_t correct_size = 2;
|
||||
if (cur_public_key.size() < correct_size) {
|
||||
MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
|
||||
return false;
|
||||
}
|
||||
// update new item to memory server.
|
||||
fl::KeysPb keys;
|
||||
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
|
||||
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
|
||||
fl::PairClientKeys pair_client_keys_pb;
|
||||
pair_client_keys_pb.set_fl_id(fl_id);
|
||||
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
|
||||
fl::PBMetadata client_and_keys_pb;
|
||||
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
|
||||
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
||||
return retcode;
|
||||
}
|
||||
|
||||
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
|
||||
const schema::RequestExchangeKeys *exchange_keys_req) {
|
||||
std::string fl_id = exchange_keys_req->fl_id()->str();
|
||||
|
@ -257,6 +257,55 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
|
|||
return retcode;
|
||||
}
|
||||
|
||||
bool CipherMetaStorage::UpdateStableClientKeyToServer(const char *list_name,
|
||||
const schema::RequestExchangeKeys *exchange_keys_req) {
|
||||
std::string fl_id = exchange_keys_req->fl_id()->str();
|
||||
auto fbs_spk = exchange_keys_req->s_pk();
|
||||
if (fbs_spk == nullptr) {
|
||||
MS_LOG(ERROR) << "Public key from exchange_keys_req is null";
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t spk_len = fbs_spk->size();
|
||||
|
||||
// transform fbs_spk to a vector: public_key
|
||||
std::vector<uint8_t> spk(spk_len);
|
||||
bool ret_create_code_spk = CreateArray<uint8_t>(&spk, *fbs_spk);
|
||||
if (!ret_create_code_spk) {
|
||||
MS_LOG(ERROR) << "Create array for public keys failed";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto fbs_pw_iv = exchange_keys_req->pw_iv();
|
||||
std::vector<char> pw_iv;
|
||||
if (fbs_pw_iv == nullptr) {
|
||||
MS_LOG(WARNING) << "pw_iv in exchange_keys_req is nullptr";
|
||||
} else {
|
||||
pw_iv.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
|
||||
}
|
||||
|
||||
auto fbs_pw_salt = exchange_keys_req->pw_salt();
|
||||
std::vector<char> pw_salt;
|
||||
if (fbs_pw_salt == nullptr) {
|
||||
MS_LOG(WARNING) << "pw_salt in exchange_keys_req is nullptr";
|
||||
} else {
|
||||
pw_salt.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
|
||||
}
|
||||
|
||||
// update new item to memory server.
|
||||
fl::KeysPb keys;
|
||||
keys.add_key()->assign(spk.begin(), spk.end());
|
||||
keys.set_pw_iv(pw_iv.data(), pw_iv.size());
|
||||
keys.set_pw_salt(pw_salt.data(), pw_salt.size());
|
||||
fl::PairClientKeys pair_client_keys_pb;
|
||||
pair_client_keys_pb.set_fl_id(fl_id);
|
||||
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
|
||||
fl::PBMetadata client_and_keys_pb;
|
||||
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
|
||||
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
|
||||
return retcode;
|
||||
}
|
||||
|
||||
bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) {
|
||||
// update new item to memory server.
|
||||
fl::OneClientNoises noises_pb;
|
||||
|
@ -325,5 +374,16 @@ void CipherMetaStorage::RegisterClass() {
|
|||
fl::PBMetadata client_noises;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientNoises, client_noises);
|
||||
}
|
||||
|
||||
void CipherMetaStorage::RegisterStablePWClass() {
|
||||
fl::PBMetadata exchange_keys_client_list;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
|
||||
exchange_keys_client_list);
|
||||
fl::PBMetadata get_keys_client_list;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetKeysClientList,
|
||||
get_keys_client_list);
|
||||
fl::PBMetadata clients_keys;
|
||||
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
|
||||
}
|
||||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -74,6 +74,7 @@ class CipherMetaStorage {
|
|||
public:
|
||||
// Register the shared value involved in the security aggregation.
|
||||
void RegisterClass();
|
||||
void RegisterStablePWClass();
|
||||
|
||||
// Register Prime.
|
||||
void RegisterPrime(const char *list_name, const std::string &prime);
|
||||
|
@ -87,17 +88,19 @@ class CipherMetaStorage {
|
|||
// Get client keys from shared server.
|
||||
void GetClientKeysFromServer(const char *list_name,
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list);
|
||||
// Get stable secure aggregation's client key from shared server.
|
||||
void GetStableClientKeysFromServer(const char *list_name,
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list);
|
||||
void GetClientIVsFromServer(const char *list_name,
|
||||
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list);
|
||||
// Get client noises from shared server.
|
||||
bool GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise);
|
||||
// Update client fl_id to shared server.
|
||||
bool UpdateClientToServer(const char *list_name, const std::string &fl_id);
|
||||
// Update client key to shared server.
|
||||
bool UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
|
||||
const std::vector<std::vector<uint8_t>> &cur_public_key);
|
||||
// Update client key with signature to shared server.
|
||||
bool UpdateClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req);
|
||||
// Update stable secure aggregation's client key to shared server.
|
||||
bool UpdateStableClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req);
|
||||
// Update client noise to shared server.
|
||||
bool UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise);
|
||||
// Update client share to shared server.
|
||||
|
|
|
@ -66,7 +66,7 @@ int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) const {
|
|||
}
|
||||
|
||||
int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) const {
|
||||
if (pubKeyBytes == nullptr || len == nullptr || evpPrivKey == nullptr) {
|
||||
if (evpPrivKey == nullptr) {
|
||||
MS_LOG(ERROR) << "input pubKeyBytes invalid.";
|
||||
return -1;
|
||||
}
|
||||
|
|
|
@ -125,7 +125,10 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
|
|||
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||
return {};
|
||||
}
|
||||
|
||||
if (metadata_.count(name) == 0) {
|
||||
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
|
||||
return {};
|
||||
}
|
||||
uint32_t stored_rank = router_->Find(name);
|
||||
MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
|
||||
if (local_rank_ == stored_rank) {
|
||||
|
@ -257,42 +260,15 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
|
|||
} else if (meta.has_prime()) {
|
||||
metadata_[name] = meta;
|
||||
} else if (meta.has_pair_client_keys()) {
|
||||
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
|
||||
auto &fl_id = meta.pair_client_keys().fl_id();
|
||||
auto &client_keys = meta.pair_client_keys().client_keys();
|
||||
// Check whether the new item already exists.
|
||||
bool add_flag = true;
|
||||
for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) {
|
||||
if (fl_id == iter->first) {
|
||||
add_flag = false;
|
||||
MS_LOG(ERROR) << "Leader server updating value for " << name
|
||||
<< " failed: The Protobuffer of this value already exists.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (add_flag) {
|
||||
client_keys_map[fl_id] = client_keys;
|
||||
} else {
|
||||
bool keys_update_succeed = UpdatePairClientKeys(name, meta);
|
||||
if (!keys_update_succeed) {
|
||||
MS_LOG(ERROR) << "Update pair_client_keys failed.";
|
||||
return false;
|
||||
}
|
||||
} else if (meta.has_pair_client_shares()) {
|
||||
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
|
||||
auto &fl_id = meta.pair_client_shares().fl_id();
|
||||
auto &client_shares = meta.pair_client_shares().client_shares();
|
||||
// google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter;
|
||||
// Check whether the new item already exists.
|
||||
bool add_flag = true;
|
||||
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) {
|
||||
if (fl_id == iter->first) {
|
||||
add_flag = false;
|
||||
MS_LOG(ERROR) << "Leader server updating value for " << name
|
||||
<< " failed: The Protobuffer of this value already exists.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (add_flag) {
|
||||
client_shares_map[fl_id] = client_shares;
|
||||
} else {
|
||||
bool shares_update_succeed = UpdatePairClientShares(name, meta);
|
||||
if (!shares_update_succeed) {
|
||||
MS_LOG(ERROR) << "Update pair_client_shares failed.";
|
||||
return false;
|
||||
}
|
||||
} else if (meta.has_one_client_noises()) {
|
||||
|
@ -310,6 +286,51 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DistributedMetadataStore::UpdatePairClientKeys(const std::string &name, const PBMetadata &meta) {
|
||||
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
|
||||
auto &fl_id = meta.pair_client_keys().fl_id();
|
||||
auto &client_keys = meta.pair_client_keys().client_keys();
|
||||
// Check whether the new item already exists.
|
||||
bool add_flag = true;
|
||||
for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) {
|
||||
if (fl_id == iter->first) {
|
||||
add_flag = false;
|
||||
MS_LOG(ERROR) << "Leader server updating value for " << name
|
||||
<< " failed: The Protobuffer of this value already exists.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (add_flag) {
|
||||
client_keys_map[fl_id] = client_keys;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool DistributedMetadataStore::UpdatePairClientShares(const std::string &name, const PBMetadata &meta) {
|
||||
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
|
||||
auto &fl_id = meta.pair_client_shares().fl_id();
|
||||
auto &client_shares = meta.pair_client_shares().client_shares();
|
||||
// google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter;
|
||||
// Check whether the new item already exists.
|
||||
bool add_flag = true;
|
||||
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) {
|
||||
if (fl_id == iter->first) {
|
||||
add_flag = false;
|
||||
MS_LOG(ERROR) << "Leader server updating value for " << name
|
||||
<< " failed: The Protobuffer of this value already exists.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (add_flag) {
|
||||
client_shares_map[fl_id] = client_shares;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} // namespace server
|
||||
} // namespace fl
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -88,6 +88,12 @@ class DistributedMetadataStore {
|
|||
// Do updating metadata in the server where the metadata for the name is stored.
|
||||
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
// Update client keys stored in server
|
||||
bool UpdatePairClientKeys(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
// Update client shares stored in server
|
||||
bool UpdatePairClientShares(const std::string &name, const PBMetadata &meta);
|
||||
|
||||
// Members for the communication between servers.
|
||||
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||
std::shared_ptr<ps::core::TcpCommunicator> communicator_;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <vector>
|
||||
#include <utility>
|
||||
#include "fl/worker/fl_worker.h"
|
||||
#include "fl/armour/secure_protocol/key_agreement.h"
|
||||
#include "utils/ms_exception.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -173,8 +174,26 @@ uint64_t FLWorker::fl_iteration_num() const { return iteration_num_.load(); }
|
|||
|
||||
void FLWorker::set_data_size(int data_size) { data_size_ = data_size; }
|
||||
|
||||
void FLWorker::set_secret_pk(armour::PrivateKey *secret_pk) { secret_pk_ = secret_pk; }
|
||||
|
||||
void FLWorker::set_pw_salt(std::vector<uint8_t> pw_salt) { pw_salt_ = pw_salt; }
|
||||
|
||||
void FLWorker::set_pw_iv(std::vector<uint8_t> pw_iv) { pw_iv_ = pw_iv; }
|
||||
|
||||
void FLWorker::set_public_keys_list(std::vector<EncryptPublicKeys> public_keys_list) {
|
||||
public_keys_list_ = public_keys_list;
|
||||
}
|
||||
|
||||
int FLWorker::data_size() const { return data_size_; }
|
||||
|
||||
armour::PrivateKey *FLWorker::secret_pk() const { return secret_pk_; }
|
||||
|
||||
std::vector<uint8_t> FLWorker::pw_salt() const { return pw_salt_; }
|
||||
|
||||
std::vector<uint8_t> FLWorker::pw_iv() const { return pw_iv_; }
|
||||
|
||||
std::vector<EncryptPublicKeys> FLWorker::public_keys_list() const { return public_keys_list_; }
|
||||
|
||||
std::string FLWorker::fl_name() const { return ps::kServerModeFL; }
|
||||
|
||||
std::string FLWorker::fl_id() const { return std::to_string(rank_id_); }
|
||||
|
|
|
@ -23,11 +23,19 @@
|
|||
#include "proto/comm.pb.h"
|
||||
#include "schema/fl_job_generated.h"
|
||||
#include "schema/cipher_generated.h"
|
||||
#include "fl/armour/secure_protocol/key_agreement.h"
|
||||
#include "ps/ps_context.h"
|
||||
#include "ps/core/worker_node.h"
|
||||
#include "ps/core/cluster_metadata.h"
|
||||
#include "ps/core/communicator/tcp_communicator.h"
|
||||
|
||||
struct EncryptPublicKeys {
|
||||
std::string flID;
|
||||
std::vector<uint8_t> publicKey;
|
||||
std::vector<uint8_t> pwIV;
|
||||
std::vector<uint8_t> pwSalt;
|
||||
};
|
||||
|
||||
namespace mindspore {
|
||||
namespace fl {
|
||||
using FBBuilder = flatbuffers::FlatBufferBuilder;
|
||||
|
@ -88,6 +96,18 @@ class FLWorker {
|
|||
void set_data_size(int data_size);
|
||||
int data_size() const;
|
||||
|
||||
void set_secret_pk(armour::PrivateKey *secret_pk);
|
||||
armour::PrivateKey *secret_pk() const;
|
||||
|
||||
void set_pw_salt(std::vector<uint8_t> pw_salt);
|
||||
std::vector<uint8_t> pw_salt() const;
|
||||
|
||||
void set_pw_iv(std::vector<uint8_t> pw_iv);
|
||||
std::vector<uint8_t> pw_iv() const;
|
||||
|
||||
void set_public_keys_list(std::vector<EncryptPublicKeys> public_keys_list);
|
||||
std::vector<EncryptPublicKeys> public_keys_list() const;
|
||||
|
||||
std::string fl_name() const;
|
||||
std::string fl_id() const;
|
||||
|
||||
|
@ -105,7 +125,11 @@ class FLWorker {
|
|||
worker_step_num_per_iteration_(1),
|
||||
server_iteration_state_(IterationState::kCompleted),
|
||||
worker_iteration_state_(IterationState::kCompleted),
|
||||
safemode_(false) {}
|
||||
safemode_(false),
|
||||
secret_pk_(nullptr),
|
||||
pw_salt_({}),
|
||||
pw_iv_({}),
|
||||
public_keys_list_({}) {}
|
||||
~FLWorker() = default;
|
||||
FLWorker(const FLWorker &) = delete;
|
||||
FLWorker &operator=(const FLWorker &) = delete;
|
||||
|
@ -152,6 +176,18 @@ class FLWorker {
|
|||
|
||||
// The flag that represents whether worker is in safemode, which is decided by both worker and server iteration state.
|
||||
std::atomic_bool safemode_;
|
||||
|
||||
// The private key used for computing the pairwise encryption's secret.
|
||||
armour::PrivateKey *secret_pk_;
|
||||
|
||||
// The salt value used for generate pairwise noise.
|
||||
std::vector<uint8_t> pw_salt_;
|
||||
|
||||
// The initialization vector value used for generate pairwise noise.
|
||||
std::vector<uint8_t> pw_iv_;
|
||||
|
||||
// The public keys used for computing the pairwise encryption's secret.
|
||||
std::vector<EncryptPublicKeys> public_keys_list_;
|
||||
};
|
||||
} // namespace worker
|
||||
} // namespace fl
|
||||
|
|
|
@ -890,6 +890,11 @@ bool StartServerAction(const ResourcePtr &res) {
|
|||
rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
|
||||
rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
|
||||
}
|
||||
if (encrypt_type == ps::kStablePWEncryptType) {
|
||||
MS_LOG(INFO) << "Add stable secure aggregation rounds.";
|
||||
rounds_config.push_back({"exchangeKeys", true, cipher_time_window, true, exchange_keys_threshold});
|
||||
rounds_config.push_back({"getKeys", true, cipher_time_window, true, get_keys_threshold});
|
||||
}
|
||||
#endif
|
||||
fl::server::CipherConfig cipher_config = {
|
||||
share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,
|
||||
|
|
|
@ -56,7 +56,9 @@ enum class TcpUserCommand {
|
|||
kNewInstance,
|
||||
kQueryInstance,
|
||||
kEnableFLS,
|
||||
kDisableFLS
|
||||
kDisableFLS,
|
||||
kExchangeKeys,
|
||||
kGetKeys
|
||||
};
|
||||
|
||||
// CommunicatorBase is used to receive request and send response for server.
|
||||
|
|
|
@ -53,6 +53,8 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
|||
{TcpUserCommand::kProceedToNextIter, "proceedToNextIter"},
|
||||
{TcpUserCommand::kEndLastIter, "endLastIter"},
|
||||
{TcpUserCommand::kStartFLJob, "startFLJob"},
|
||||
{TcpUserCommand::kExchangeKeys, "exchangeKeys"},
|
||||
{TcpUserCommand::kGetKeys, "getKeys"},
|
||||
{TcpUserCommand::kUpdateModel, "updateModel"},
|
||||
{TcpUserCommand::kGetModel, "getModel"},
|
||||
{TcpUserCommand::kPushMetrics, "pushMetrics"},
|
||||
|
|
|
@ -199,9 +199,10 @@ void PSContext::set_server_mode(const std::string &server_mode) {
|
|||
const std::string &PSContext::server_mode() const { return server_mode_; }
|
||||
|
||||
void PSContext::set_encrypt_type(const std::string &encrypt_type) {
|
||||
if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) {
|
||||
if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType &&
|
||||
encrypt_type != kStablePWEncryptType) {
|
||||
MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
|
||||
<< kDPEncryptType << " or " << kPWEncryptType;
|
||||
<< kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType;
|
||||
return;
|
||||
}
|
||||
encrypt_type_ = encrypt_type;
|
||||
|
|
|
@ -37,6 +37,7 @@ constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
|
|||
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
|
||||
constexpr char kDPEncryptType[] = "DP_ENCRYPT";
|
||||
constexpr char kPWEncryptType[] = "PW_ENCRYPT";
|
||||
constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT";
|
||||
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
|
||||
|
||||
// Use binary data to represent federated learning server's context so that we can judge which round resets the
|
||||
|
|
|
@ -985,11 +985,12 @@ def set_fl_context(**kwargs):
|
|||
client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01.
|
||||
dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is
|
||||
suggested to be 0.5~2. Default: 1.0.
|
||||
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT' or
|
||||
'PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied for clients and the privacy
|
||||
protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If
|
||||
'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' model from stealing.
|
||||
Default: 'NOT_ENCRYPT'.
|
||||
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT',
|
||||
'PW_ENCRYPT' or 'STABLE_PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied
|
||||
for clients and the privacy protection effect would be determined by dp_eps, dp_delta and dp_norm_clip
|
||||
as described above. If 'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients'
|
||||
model from stealing in cross-device scenario. If 'STABLE_PW_ENCRYPT', pairwise secure aggregation would
|
||||
be applied to protect clients' model from stealing in cross-silo scenario. Default: 'NOT_ENCRYPT'.
|
||||
config_file_path (string): Configuration file path used by recovery. Default: ''.
|
||||
scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
|
||||
enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.
|
||||
|
|
|
@ -92,7 +92,7 @@ from ._quant_ops import *
|
|||
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
||||
ConfusionMatrix, PopulationCount, UpdateState, Load,
|
||||
CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight,
|
||||
PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc)
|
||||
PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc, ExchangeKeys, GetKeys)
|
||||
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
|
||||
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
|
||||
CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,
|
||||
|
|
|
@ -885,6 +885,40 @@ class GetModel(PrimitiveWithInfer):
|
|||
return mstype.float32
|
||||
|
||||
|
||||
class ExchangeKeys(PrimitiveWithInfer):
|
||||
"""
|
||||
Exchange pairwise public keys for federated learning worker.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.init_prim_io_names(inputs=[], outputs=["result"])
|
||||
|
||||
def infer_shape(self):
|
||||
return [1]
|
||||
|
||||
def infer_dtype(self):
|
||||
return mstype.float32
|
||||
|
||||
|
||||
class GetKeys(PrimitiveWithInfer):
|
||||
"""
|
||||
Get pairwise public keys for federated learning worker.
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
self.add_prim_attr('side_effect_mem', True)
|
||||
self.init_prim_io_names(inputs=[], outputs=["result"])
|
||||
|
||||
def infer_shape(self):
|
||||
return [1]
|
||||
|
||||
def infer_dtype(self):
|
||||
return mstype.float32
|
||||
|
||||
|
||||
class identity(Primitive):
|
||||
"""
|
||||
Makes a identify primitive, used for pynative mode.
|
||||
|
|
|
@ -31,6 +31,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
|
|||
parser.add_argument("--local_worker_num", type=int, default=-1)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
parser.add_argument("--dataset_path", type=str, default="")
|
||||
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
device_target = args.device_target
|
||||
|
@ -47,6 +48,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
|
|||
local_worker_num = args.local_worker_num
|
||||
config_file_path = args.config_file_path
|
||||
dataset_path = args.dataset_path
|
||||
encrypt_type = args.encrypt_type
|
||||
|
||||
if local_worker_num == -1:
|
||||
local_worker_num = worker_num
|
||||
|
@ -73,6 +75,7 @@ for i in range(local_worker_num):
|
|||
cmd_worker += " --client_learning_rate=" + str(client_learning_rate)
|
||||
cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
|
||||
cmd_worker += " --dataset_path=" + str(dataset_path)
|
||||
cmd_worker += " --encrypt_type=" + str(encrypt_type)
|
||||
cmd_worker += " --user_id=" + str(i)
|
||||
cmd_worker += " > worker.log 2>&1 &"
|
||||
|
||||
|
|
|
@ -58,6 +58,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
|
|||
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
|
||||
parser.add_argument("--config_file_path", type=str, default="")
|
||||
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
|
||||
parser.add_argument("--cipher_time_window", type=int, default=300000)
|
||||
parser.add_argument("--dataset_path", type=str, default="")
|
||||
# The user_id is used to set each worker's dataset path.
|
||||
parser.add_argument("--user_id", type=str, default="0")
|
||||
|
@ -87,6 +88,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
|
|||
scheduler_manage_port = args.scheduler_manage_port
|
||||
config_file_path = args.config_file_path
|
||||
encrypt_type = args.encrypt_type
|
||||
cipher_time_window = args.cipher_time_window
|
||||
dataset_path = args.dataset_path
|
||||
user_id = args.user_id
|
||||
|
||||
|
@ -111,7 +113,8 @@ ctx = {
|
|||
"worker_step_num_per_iteration": worker_step_num_per_iteration,
|
||||
"scheduler_manage_port": scheduler_manage_port,
|
||||
"config_file_path": config_file_path,
|
||||
"encrypt_type": encrypt_type
|
||||
"encrypt_type": encrypt_type,
|
||||
"cipher_time_window": cipher_time_window,
|
||||
}
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
|
@ -285,6 +288,24 @@ class StartFLJob(nn.Cell):
|
|||
return self.start_fl_job()
|
||||
|
||||
|
||||
class ExchangeKeys(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ExchangeKeys, self).__init__()
|
||||
self.exchange_keys = P.ExchangeKeys()
|
||||
|
||||
def construct(self):
|
||||
return self.exchange_keys()
|
||||
|
||||
|
||||
class GetKeys(nn.Cell):
|
||||
def __init__(self):
|
||||
super(GetKeys, self).__init__()
|
||||
self.get_keys = P.GetKeys()
|
||||
|
||||
def construct(self):
|
||||
return self.get_keys()
|
||||
|
||||
|
||||
class UpdateAndGetModel(nn.Cell):
|
||||
def __init__(self, weights):
|
||||
super(UpdateAndGetModel, self).__init__()
|
||||
|
@ -328,6 +349,11 @@ def train():
|
|||
if context.get_fl_context("ms_role") == "MS_WORKER":
|
||||
start_fl_job = StartFLJob(dataset.get_dataset_size() * args.client_batch_size)
|
||||
start_fl_job()
|
||||
if encrypt_type == "STABLE_PW_ENCRYPT":
|
||||
exchange_keys = ExchangeKeys()
|
||||
exchange_keys()
|
||||
get_keys = GetKeys()
|
||||
get_keys()
|
||||
|
||||
for _ in range(epoch):
|
||||
print("step is ", epoch, flush=True)
|
||||
|
|
Loading…
Reference in New Issue