Add exchangeKeys ops and getKeys ops for STABLE_PW_ENCRYPT

fix review syggestions
This commit is contained in:
jin-xiulang 2021-11-16 17:26:15 +08:00
parent b2dfb595e7
commit f972de90d0
24 changed files with 756 additions and 80 deletions

View File

@ -60,6 +60,8 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/start_fl_job_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/update_model_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/push_metrics_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/get_keys_kernel.cc")
list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/exchange_keys_kernel.cc")
endif()
if(ENABLE_SECURITY)

View File

@ -0,0 +1,168 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/fl/exchange_keys_kernel.h"
namespace mindspore {
namespace kernel {
constexpr int iv_vec_len = 16;
constexpr int salt_len = 32;
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) {
MS_LOG(INFO) << "Launching client ExchangeKeysKernel";
if (!BuildExchangeKeysReq(fbb_)) {
MS_LOG(EXCEPTION) << "Building request for ExchangeKeys failed.";
return false;
}
std::shared_ptr<std::vector<unsigned char>> exchange_keys_rsp_msg = nullptr;
if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(), fbb_->GetSize(),
ps::core::TcpUserCommand::kExchangeKeys,
&exchange_keys_rsp_msg)) {
MS_LOG(EXCEPTION) << "Sending request for ExchangeKeys to server " << target_server_rank_ << " failed.";
return false;
}
if (exchange_keys_rsp_msg == nullptr) {
MS_LOG(EXCEPTION) << "Received message pointer is nullptr.";
return false;
}
flatbuffers::Verifier verifier(exchange_keys_rsp_msg->data(), exchange_keys_rsp_msg->size());
if (!verifier.VerifyBuffer<schema::ResponseExchangeKeys>()) {
MS_LOG(EXCEPTION) << "The schema of ResponseExchangeKeys is invalid.";
return false;
}
const schema::ResponseExchangeKeys *exchange_keys_rsp =
flatbuffers::GetRoot<schema::ResponseExchangeKeys>(exchange_keys_rsp_msg->data());
MS_EXCEPTION_IF_NULL(exchange_keys_rsp);
auto response_code = exchange_keys_rsp->retcode();
if ((response_code != schema::ResponseCode_SUCCEED) && (response_code != schema::ResponseCode_OutOfTime)) {
MS_LOG(EXCEPTION) << "Launching exchange keys job for worker failed. Reason: " << exchange_keys_rsp->reason();
}
MS_LOG(INFO) << "Exchange keys successfully.";
return true;
}
void ExchangeKeysKernel::Init(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
fl_id_ = fl::worker::FLWorker::GetInstance().fl_id();
server_num_ = fl::worker::FLWorker::GetInstance().server_num();
rank_id_ = fl::worker::FLWorker::GetInstance().rank_id();
if (rank_id_ == UINT32_MAX) {
MS_LOG(EXCEPTION) << "Federated worker is not initialized yet.";
return;
}
if (server_num_ <= 0) {
MS_LOG(EXCEPTION) << "Server number should be larger than 0, but got: " << server_num_;
return;
}
target_server_rank_ = rank_id_ % server_num_;
MS_LOG(INFO) << "Initializing ExchangeKeys kernel"
<< ", fl_id: " << fl_id_ << ". Request will be sent to server " << target_server_rank_;
fbb_ = std::make_shared<fl::FBBuilder>();
MS_EXCEPTION_IF_NULL(fbb_);
input_size_list_.push_back(sizeof(int));
output_size_list_.push_back(sizeof(float));
MS_LOG(INFO) << "Initialize ExchangeKeys kernel successfully.";
}
void ExchangeKeysKernel::InitKernel(const CNodePtr &kernel_node) { return; }
bool ExchangeKeysKernel::BuildExchangeKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb) {
MS_EXCEPTION_IF_NULL(fbb);
// generate initialization vector value used for generate pairwise noise
std::vector<uint8_t> pw_iv_(iv_vec_len);
std::vector<uint8_t> pw_salt_(salt_len);
auto ret = RAND_bytes(pw_iv_.data(), iv_vec_len);
if (ret != 1) {
MS_LOG(ERROR) << "RAND_bytes error, failed to init pw_iv_.";
return false;
}
// generate salt value used for generate pairwise noise
ret = RAND_bytes(pw_salt_.data(), salt_len);
if (ret != 1) {
MS_LOG(ERROR) << "RAND_bytes error, failed to init pw_salt_.";
return false;
}
// save pw_salt and pw_iv at local
fl::worker::FLWorker::GetInstance().set_pw_salt(pw_salt_);
fl::worker::FLWorker::GetInstance().set_pw_iv(pw_iv_);
// get public key bytes
std::vector<uint8_t> pubkey_bytes = GetPubicKeyBytes();
if (pubkey_bytes.size() == 0) {
MS_LOG(EXCEPTION) << "Get public key failed.";
return false;
}
// build data which will be send to server
int iter = fl::worker::FLWorker::GetInstance().fl_iteration_num();
auto fbs_fl_id = fbb->CreateString(fl_id_);
auto fbs_public_key = fbb->CreateVector(pubkey_bytes.data(), pubkey_bytes.size());
auto fbs_pw_iv = fbb->CreateVector(pw_iv_.data(), iv_vec_len);
auto fbs_pw_salt = fbb->CreateVector(pw_salt_.data(), salt_len);
schema::RequestExchangeKeysBuilder req_exchange_key_builder(*(fbb.get()));
req_exchange_key_builder.add_fl_id(fbs_fl_id);
req_exchange_key_builder.add_s_pk(fbs_public_key);
req_exchange_key_builder.add_iteration(iter);
req_exchange_key_builder.add_pw_iv(fbs_pw_iv);
req_exchange_key_builder.add_pw_salt(fbs_pw_salt);
auto req_fl_job = req_exchange_key_builder.Finish();
fbb->Finish(req_fl_job);
MS_LOG(INFO) << "BuildExchangeKeysReq successfully.";
return true;
}
std::vector<uint8_t> ExchangeKeysKernel::GetPubicKeyBytes() {
// generate private key of secret
armour::PrivateKey *sPriKeyPtr = armour::KeyAgreement::GeneratePrivKey();
fl::worker::FLWorker::GetInstance().set_secret_pk(sPriKeyPtr);
// get public bytes length
size_t pubLen;
uint8_t *secret_pubkey_ptr = NULL;
auto ret = sPriKeyPtr->GetPublicBytes(&pubLen, secret_pubkey_ptr);
if (ret != 0 || pubLen == 0) {
MS_LOG(ERROR) << "GetPublicBytes error, failed to get public_key bytes length.";
return {};
}
// pubLen has been updated, now get public_key bytes
secret_pubkey_ptr = reinterpret_cast<uint8_t *>(malloc(pubLen));
ret = sPriKeyPtr->GetPublicBytes(&pubLen, secret_pubkey_ptr);
if (ret != 0) {
free(secret_pubkey_ptr);
MS_LOG(ERROR) << "GetPublicBytes error, failed to get public_key bytes.";
return {};
}
// transform key buffer to uint8_t vector
std::vector<uint8_t> pubkey_bytes(pubLen);
for (int i = 0; i < SizeToInt(pubLen); i++) {
pubkey_bytes[i] = secret_pubkey_ptr[i];
}
free(secret_pubkey_ptr);
return pubkey_bytes;
}
MS_REG_CPU_KERNEL(ExchangeKeys, KernelAttr().AddOutputAttr(kNumberTypeFloat32), ExchangeKeysKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_EXCHANGE_KEYS_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_EXCHANGE_KEYS_H_
#include <vector>
#include <string>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "fl/worker/fl_worker.h"
#include "fl/armour/secure_protocol/key_agreement.h"
namespace mindspore {
namespace kernel {
class ExchangeKeysKernel : public CPUKernel {
public:
ExchangeKeysKernel() = default;
~ExchangeKeysKernel() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) override;
void Init(const CNodePtr &kernel_node) override;
void InitKernel(const CNodePtr &kernel_node) override;
private:
bool BuildExchangeKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb);
std::vector<uint8_t> GetPubicKeyBytes();
uint32_t rank_id_;
uint32_t server_num_;
uint32_t target_server_rank_;
std::string fl_id_;
std::shared_ptr<fl::FBBuilder> fbb_;
armour::PrivateKey *secret_prikey_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_EXCHANGE_KEYS_H_

View File

@ -0,0 +1,141 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/fl/get_keys_kernel.h"
namespace mindspore {
namespace kernel {
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) {
MS_LOG(INFO) << "Launching client GetKeysKernel";
BuildGetKeysReq(fbb_);
std::shared_ptr<std::vector<unsigned char>> get_keys_rsp_msg = nullptr;
if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(), fbb_->GetSize(),
ps::core::TcpUserCommand::kGetKeys, &get_keys_rsp_msg)) {
MS_LOG(EXCEPTION) << "Sending request for GetKeys to server " << target_server_rank_ << " failed.";
return false;
}
if (get_keys_rsp_msg == nullptr) {
MS_LOG(EXCEPTION) << "Received message pointer is nullptr.";
return false;
}
flatbuffers::Verifier verifier(get_keys_rsp_msg->data(), get_keys_rsp_msg->size());
if (!verifier.VerifyBuffer<schema::ReturnExchangeKeys>()) {
MS_LOG(EXCEPTION) << "The schema of ResponseGetKeys is invalid.";
return false;
}
const schema::ReturnExchangeKeys *get_keys_rsp =
flatbuffers::GetRoot<schema::ReturnExchangeKeys>(get_keys_rsp_msg->data());
MS_EXCEPTION_IF_NULL(get_keys_rsp);
auto response_code = get_keys_rsp->retcode();
if ((response_code != schema::ResponseCode_SUCCEED) && (response_code != schema::ResponseCode_OutOfTime)) {
MS_LOG(EXCEPTION) << "Launching get keys job for worker failed. response_code: " << response_code;
}
bool save_keys_succeed = SavePublicKeyList(get_keys_rsp->remote_publickeys());
if (!save_keys_succeed) {
MS_LOG(EXCEPTION) << "Save received remote keys failed.";
return false;
}
MS_LOG(INFO) << "Get keys successfully.";
return true;
}
void GetKeysKernel::Init(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
fl_id_ = fl::worker::FLWorker::GetInstance().fl_id();
server_num_ = fl::worker::FLWorker::GetInstance().server_num();
rank_id_ = fl::worker::FLWorker::GetInstance().rank_id();
if (rank_id_ == UINT32_MAX) {
MS_LOG(EXCEPTION) << "Federated worker is not initialized yet.";
return;
}
if (server_num_ <= 0) {
MS_LOG(EXCEPTION) << "Server number should be larger than 0, but got: " << server_num_;
return;
}
target_server_rank_ = rank_id_ % server_num_;
MS_LOG(INFO) << "Initializing GetKeys kernel"
<< ", fl_id: " << fl_id_ << ". Request will be sent to server " << target_server_rank_;
fbb_ = std::make_shared<fl::FBBuilder>();
MS_EXCEPTION_IF_NULL(fbb_);
input_size_list_.push_back(sizeof(int));
output_size_list_.push_back(sizeof(float));
MS_LOG(INFO) << "Initialize GetKeys kernel successfully.";
}
void GetKeysKernel::InitKernel(const CNodePtr &kernel_node) { return; }
void GetKeysKernel::BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb) {
MS_EXCEPTION_IF_NULL(fbb);
int iter = fl::worker::FLWorker::GetInstance().fl_iteration_num();
auto fbs_fl_id = fbb->CreateString(fl_id_);
schema::GetExchangeKeysBuilder get_keys_builder(*(fbb.get()));
get_keys_builder.add_fl_id(fbs_fl_id);
get_keys_builder.add_iteration(iter);
auto req_fl_job = get_keys_builder.Finish();
fbb->Finish(req_fl_job);
MS_LOG(INFO) << "BuildGetKeysReq successfully.";
}
bool GetKeysKernel::SavePublicKeyList(auto remote_public_key) {
if (remote_public_key == nullptr) {
MS_LOG(EXCEPTION) << "Input remote_pubic_key is nullptr.";
}
int client_num = remote_public_key->size();
if (client_num <= 0) {
MS_LOG(EXCEPTION) << "Received client keys length is <= 0, please check it!";
return false;
}
// save client keys list
std::vector<EncryptPublicKeys> saved_remote_public_keys;
for (auto iter = remote_public_key->begin(); iter != remote_public_key->end(); ++iter) {
std::string fl_id = iter->fl_id()->str();
auto fbs_spk = iter->s_pk();
auto fbs_pw_iv = iter->pw_iv();
auto fbs_pw_salt = iter->pw_salt();
if (fbs_spk == nullptr || fbs_pw_iv == nullptr || fbs_pw_salt == nullptr) {
MS_LOG(WARNING) << "public key, pw_iv or pw_salt in remote_publickeys is nullptr.";
} else {
std::vector<uint8_t> spk_vector;
std::vector<uint8_t> pw_iv_vector;
std::vector<uint8_t> pw_salt_vector;
spk_vector.assign(fbs_spk->begin(), fbs_spk->end());
pw_iv_vector.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
pw_salt_vector.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
EncryptPublicKeys public_keys_i;
public_keys_i.flID = fl_id;
public_keys_i.publicKey = spk_vector;
public_keys_i.pwIV = pw_iv_vector;
public_keys_i.pwSalt = pw_salt_vector;
saved_remote_public_keys.push_back(public_keys_i);
MS_LOG(INFO) << "Add public keys of client:" << fl_id << " successfully.";
}
}
fl::worker::FLWorker::GetInstance().set_public_keys_list(saved_remote_public_keys);
return true;
}
MS_REG_CPU_KERNEL(GetKeys, KernelAttr().AddOutputAttr(kNumberTypeFloat32), GetKeysKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_KEYS_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_KEYS_H_
#include <vector>
#include <string>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "fl/worker/fl_worker.h"
#include "fl/armour/secure_protocol/key_agreement.h"
namespace mindspore {
namespace kernel {
class GetKeysKernel : public CPUKernel {
public:
GetKeysKernel() = default;
~GetKeysKernel() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) override;
void Init(const CNodePtr &kernel_node) override;
void InitKernel(const CNodePtr &kernel_node) override;
private:
void BuildGetKeysReq(const std::shared_ptr<fl::FBBuilder> &fbb);
bool SavePublicKeyList(auto remote_public_key);
uint32_t rank_id_;
uint32_t server_num_;
uint32_t target_server_rank_;
std::string fl_id_;
std::shared_ptr<fl::FBBuilder> fbb_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_KEYS_H_

View File

@ -81,6 +81,11 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
}
MS_LOG(INFO) << " CipherInit::Init Success";
}
if (param.encrypt_type == mindspore::ps::kStablePWEncryptType) {
cipher_meta_storage_.RegisterStablePWClass();
MS_LOG(INFO) << "Register metadata for StablePWEncrypt is finished.";
}
return true;
}

View File

@ -34,7 +34,12 @@ bool CipherKeys::GetKeys(const size_t cur_iterator, const std::string &next_req_
}
// get clientlist from memory server.
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) {
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
} else {
cipher_init_->cipher_meta_storage_.GetStableClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
}
size_t cur_exchange_clients_num = client_public_keys.size();
std::string fl_id = get_exchange_keys_req->fl_id()->str();
@ -97,7 +102,13 @@ bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next
std::map<std::string, std::vector<std::vector<uint8_t>>> client_public_keys;
std::vector<std::string> client_list;
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list);
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) {
cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
} else {
cipher_init_->cipher_meta_storage_.GetStableClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys);
}
// step2: process new item data. and update new item data to memory server.
size_t cur_clients_num = client_list.size();
@ -119,8 +130,15 @@ bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next
return false;
}
bool retcode_key =
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
bool retcode_key;
if (encrypt_type == ps::kPWEncryptType) {
retcode_key =
cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
} else {
retcode_key =
cipher_init_->cipher_meta_storage_.UpdateStableClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req);
}
bool retcode_client =
cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id);
if (retcode_key && retcode_client) {
@ -134,7 +152,7 @@ bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next
cur_iterator);
return false;
}
}
} // namespace armour
void CipherKeys::BuildExchangeKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &fbb,
const schema::ResponseCode retcode, const std::string &reason,
@ -164,6 +182,7 @@ void CipherKeys::BuildGetKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &f
fbb->Finish(rsp_get_keys);
return;
}
const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientsKeys);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
@ -172,17 +191,27 @@ void CipherKeys::BuildGetKeysRsp(const std::shared_ptr<fl::server::FBBuilder> &f
std::string fl_id = iter->first;
fl::KeysPb keys_pb = iter->second;
auto fbs_fl_id = fbb->CreateString(fl_id);
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size());
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
std::vector<uint8_t> pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end());
auto fbs_pw_iv = fbb->CreateVector(pw_iv.data(), pw_iv.size());
std::vector<uint8_t> pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end());
auto fbs_pw_salt = fbb->CreateVector(pw_salt.data(), pw_salt.size());
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
public_keys_list.push_back(cur_public_key);
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) {
std::vector<uint8_t> cpk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<uint8_t> spk(keys_pb.key(1).begin(), keys_pb.key(1).end());
auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size());
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
public_keys_list.push_back(cur_public_key);
} else {
std::vector<uint8_t> spk(keys_pb.key(0).begin(), keys_pb.key(0).end());
auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size());
auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, 0, fbs_s_pk, fbs_pw_iv, fbs_pw_salt);
public_keys_list.push_back(cur_public_key);
}
}
auto remote_publickeys = fbb->CreateVector(public_keys_list);
auto fbs_next_req_time = fbb->CreateString(next_req_time);
schema::ReturnExchangeKeysBuilder rsp_builder(*(fbb.get()));

View File

@ -81,6 +81,26 @@ void CipherMetaStorage::GetClientKeysFromServer(
}
}
void CipherMetaStorage::GetStableClientKeysFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list) {
if (clients_keys_list == nullptr) {
MS_LOG(ERROR) << "Input clients_keys_list is nullptr";
return;
}
const fl::PBMetadata &clients_keys_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name);
const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys();
for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) {
std::string fl_id = iter->first;
fl::KeysPb keys_pb = iter->second;
std::vector<uint8_t> spk(keys_pb.key(0).begin(), keys_pb.key(0).end());
std::vector<std::vector<uint8_t>> cur_keys;
cur_keys.push_back(spk);
(void)clients_keys_list->emplace(std::pair<std::string, std::vector<std::vector<uint8_t>>>(fl_id, cur_keys));
}
}
void CipherMetaStorage::GetClientIVsFromServer(
const char *list_name, std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list) {
if (clients_ivs_list == nullptr) {
@ -171,26 +191,6 @@ void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string &
(void)sleep(time);
}
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
const std::vector<std::vector<uint8_t>> &cur_public_key) {
const size_t correct_size = 2;
if (cur_public_key.size() < correct_size) {
MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size();
return false;
}
// update new item to memory server.
fl::KeysPb keys;
keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end());
keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end());
fl::PairClientKeys pair_client_keys_pb;
pair_client_keys_pb.set_fl_id(fl_id);
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
fl::PBMetadata client_and_keys_pb;
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
return retcode;
}
bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
const schema::RequestExchangeKeys *exchange_keys_req) {
std::string fl_id = exchange_keys_req->fl_id()->str();
@ -257,6 +257,55 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name,
return retcode;
}
bool CipherMetaStorage::UpdateStableClientKeyToServer(const char *list_name,
const schema::RequestExchangeKeys *exchange_keys_req) {
std::string fl_id = exchange_keys_req->fl_id()->str();
auto fbs_spk = exchange_keys_req->s_pk();
if (fbs_spk == nullptr) {
MS_LOG(ERROR) << "Public key from exchange_keys_req is null";
return false;
}
size_t spk_len = fbs_spk->size();
// transform fbs_spk to a vector: public_key
std::vector<uint8_t> spk(spk_len);
bool ret_create_code_spk = CreateArray<uint8_t>(&spk, *fbs_spk);
if (!ret_create_code_spk) {
MS_LOG(ERROR) << "Create array for public keys failed";
return false;
}
auto fbs_pw_iv = exchange_keys_req->pw_iv();
std::vector<char> pw_iv;
if (fbs_pw_iv == nullptr) {
MS_LOG(WARNING) << "pw_iv in exchange_keys_req is nullptr";
} else {
pw_iv.assign(fbs_pw_iv->begin(), fbs_pw_iv->end());
}
auto fbs_pw_salt = exchange_keys_req->pw_salt();
std::vector<char> pw_salt;
if (fbs_pw_salt == nullptr) {
MS_LOG(WARNING) << "pw_salt in exchange_keys_req is nullptr";
} else {
pw_salt.assign(fbs_pw_salt->begin(), fbs_pw_salt->end());
}
// update new item to memory server.
fl::KeysPb keys;
keys.add_key()->assign(spk.begin(), spk.end());
keys.set_pw_iv(pw_iv.data(), pw_iv.size());
keys.set_pw_salt(pw_salt.data(), pw_salt.size());
fl::PairClientKeys pair_client_keys_pb;
pair_client_keys_pb.set_fl_id(fl_id);
pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys);
fl::PBMetadata client_and_keys_pb;
client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb);
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb);
return retcode;
}
bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise) {
// update new item to memory server.
fl::OneClientNoises noises_pb;
@ -325,5 +374,16 @@ void CipherMetaStorage::RegisterClass() {
fl::PBMetadata client_noises;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientNoises, client_noises);
}
void CipherMetaStorage::RegisterStablePWClass() {
fl::PBMetadata exchange_keys_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList,
exchange_keys_client_list);
fl::PBMetadata get_keys_client_list;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetKeysClientList,
get_keys_client_list);
fl::PBMetadata clients_keys;
fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys);
}
} // namespace armour
} // namespace mindspore

View File

@ -74,6 +74,7 @@ class CipherMetaStorage {
public:
// Register the shared value involved in the security aggregation.
void RegisterClass();
void RegisterStablePWClass();
// Register Prime.
void RegisterPrime(const char *list_name, const std::string &prime);
@ -87,17 +88,19 @@ class CipherMetaStorage {
// Get client keys from shared server.
void GetClientKeysFromServer(const char *list_name,
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list);
// Get stable secure aggregation's client key from shared server.
void GetStableClientKeysFromServer(const char *list_name,
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_keys_list);
void GetClientIVsFromServer(const char *list_name,
std::map<std::string, std::vector<std::vector<uint8_t>>> *clients_ivs_list);
// Get client noises from shared server.
bool GetClientNoisesFromServer(const char *list_name, std::vector<float> *cur_public_noise);
// Update client fl_id to shared server.
bool UpdateClientToServer(const char *list_name, const std::string &fl_id);
// Update client key to shared server.
bool UpdateClientKeyToServer(const char *list_name, const std::string &fl_id,
const std::vector<std::vector<uint8_t>> &cur_public_key);
// Update client key with signature to shared server.
bool UpdateClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req);
// Update stable secure aggregation's client key to shared server.
bool UpdateStableClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req);
// Update client noise to shared server.
bool UpdateClientNoiseToServer(const char *list_name, const std::vector<float> &cur_public_noise);
// Update client share to shared server.

View File

@ -66,7 +66,7 @@ int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) const {
}
int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) const {
if (pubKeyBytes == nullptr || len == nullptr || evpPrivKey == nullptr) {
if (evpPrivKey == nullptr) {
MS_LOG(ERROR) << "input pubKeyBytes invalid.";
return -1;
}

View File

@ -125,7 +125,10 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
return {};
}
if (metadata_.count(name) == 0) {
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
return {};
}
uint32_t stored_rank = router_->Find(name);
MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
if (local_rank_ == stored_rank) {
@ -257,42 +260,15 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
} else if (meta.has_prime()) {
metadata_[name] = meta;
} else if (meta.has_pair_client_keys()) {
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
auto &fl_id = meta.pair_client_keys().fl_id();
auto &client_keys = meta.pair_client_keys().client_keys();
// Check whether the new item already exists.
bool add_flag = true;
for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) {
if (fl_id == iter->first) {
add_flag = false;
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
break;
}
}
if (add_flag) {
client_keys_map[fl_id] = client_keys;
} else {
bool keys_update_succeed = UpdatePairClientKeys(name, meta);
if (!keys_update_succeed) {
MS_LOG(ERROR) << "Update pair_client_keys failed.";
return false;
}
} else if (meta.has_pair_client_shares()) {
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
auto &fl_id = meta.pair_client_shares().fl_id();
auto &client_shares = meta.pair_client_shares().client_shares();
// google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter;
// Check whether the new item already exists.
bool add_flag = true;
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) {
if (fl_id == iter->first) {
add_flag = false;
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
break;
}
}
if (add_flag) {
client_shares_map[fl_id] = client_shares;
} else {
bool shares_update_succeed = UpdatePairClientShares(name, meta);
if (!shares_update_succeed) {
MS_LOG(ERROR) << "Update pair_client_shares failed.";
return false;
}
} else if (meta.has_one_client_noises()) {
@ -310,6 +286,51 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
}
return true;
}
bool DistributedMetadataStore::UpdatePairClientKeys(const std::string &name, const PBMetadata &meta) {
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
auto &fl_id = meta.pair_client_keys().fl_id();
auto &client_keys = meta.pair_client_keys().client_keys();
// Check whether the new item already exists.
bool add_flag = true;
for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) {
if (fl_id == iter->first) {
add_flag = false;
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
break;
}
}
if (add_flag) {
client_keys_map[fl_id] = client_keys;
return true;
} else {
return false;
}
}
bool DistributedMetadataStore::UpdatePairClientShares(const std::string &name, const PBMetadata &meta) {
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
auto &fl_id = meta.pair_client_shares().fl_id();
auto &client_shares = meta.pair_client_shares().client_shares();
// google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter;
// Check whether the new item already exists.
bool add_flag = true;
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) {
if (fl_id == iter->first) {
add_flag = false;
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
break;
}
}
if (add_flag) {
client_shares_map[fl_id] = client_shares;
return true;
} else {
return false;
}
}
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -88,6 +88,12 @@ class DistributedMetadataStore {
// Do updating metadata in the server where the metadata for the name is stored.
bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta);
// Update client keys stored in server
bool UpdatePairClientKeys(const std::string &name, const PBMetadata &meta);
// Update client shares stored in server
bool UpdatePairClientShares(const std::string &name, const PBMetadata &meta);
// Members for the communication between servers.
std::shared_ptr<ps::core::ServerNode> server_node_;
std::shared_ptr<ps::core::TcpCommunicator> communicator_;

View File

@ -19,6 +19,7 @@
#include <vector>
#include <utility>
#include "fl/worker/fl_worker.h"
#include "fl/armour/secure_protocol/key_agreement.h"
#include "utils/ms_exception.h"
namespace mindspore {
@ -173,8 +174,26 @@ uint64_t FLWorker::fl_iteration_num() const { return iteration_num_.load(); }
void FLWorker::set_data_size(int data_size) { data_size_ = data_size; }
void FLWorker::set_secret_pk(armour::PrivateKey *secret_pk) { secret_pk_ = secret_pk; }
void FLWorker::set_pw_salt(std::vector<uint8_t> pw_salt) { pw_salt_ = pw_salt; }
void FLWorker::set_pw_iv(std::vector<uint8_t> pw_iv) { pw_iv_ = pw_iv; }
void FLWorker::set_public_keys_list(std::vector<EncryptPublicKeys> public_keys_list) {
public_keys_list_ = public_keys_list;
}
int FLWorker::data_size() const { return data_size_; }
armour::PrivateKey *FLWorker::secret_pk() const { return secret_pk_; }
std::vector<uint8_t> FLWorker::pw_salt() const { return pw_salt_; }
std::vector<uint8_t> FLWorker::pw_iv() const { return pw_iv_; }
std::vector<EncryptPublicKeys> FLWorker::public_keys_list() const { return public_keys_list_; }
std::string FLWorker::fl_name() const { return ps::kServerModeFL; }
std::string FLWorker::fl_id() const { return std::to_string(rank_id_); }

View File

@ -23,11 +23,19 @@
#include "proto/comm.pb.h"
#include "schema/fl_job_generated.h"
#include "schema/cipher_generated.h"
#include "fl/armour/secure_protocol/key_agreement.h"
#include "ps/ps_context.h"
#include "ps/core/worker_node.h"
#include "ps/core/cluster_metadata.h"
#include "ps/core/communicator/tcp_communicator.h"
struct EncryptPublicKeys {
std::string flID;
std::vector<uint8_t> publicKey;
std::vector<uint8_t> pwIV;
std::vector<uint8_t> pwSalt;
};
namespace mindspore {
namespace fl {
using FBBuilder = flatbuffers::FlatBufferBuilder;
@ -88,6 +96,18 @@ class FLWorker {
void set_data_size(int data_size);
int data_size() const;
void set_secret_pk(armour::PrivateKey *secret_pk);
armour::PrivateKey *secret_pk() const;
void set_pw_salt(std::vector<uint8_t> pw_salt);
std::vector<uint8_t> pw_salt() const;
void set_pw_iv(std::vector<uint8_t> pw_iv);
std::vector<uint8_t> pw_iv() const;
void set_public_keys_list(std::vector<EncryptPublicKeys> public_keys_list);
std::vector<EncryptPublicKeys> public_keys_list() const;
std::string fl_name() const;
std::string fl_id() const;
@ -105,7 +125,11 @@ class FLWorker {
worker_step_num_per_iteration_(1),
server_iteration_state_(IterationState::kCompleted),
worker_iteration_state_(IterationState::kCompleted),
safemode_(false) {}
safemode_(false),
secret_pk_(nullptr),
pw_salt_({}),
pw_iv_({}),
public_keys_list_({}) {}
~FLWorker() = default;
FLWorker(const FLWorker &) = delete;
FLWorker &operator=(const FLWorker &) = delete;
@ -152,6 +176,18 @@ class FLWorker {
// The flag that represents whether worker is in safemode, which is decided by both worker and server iteration state.
std::atomic_bool safemode_;
// The private key used for computing the pairwise encryption's secret.
armour::PrivateKey *secret_pk_;
// The salt value used for generate pairwise noise.
std::vector<uint8_t> pw_salt_;
// The initialization vector value used for generate pairwise noise.
std::vector<uint8_t> pw_iv_;
// The public keys used for computing the pairwise encryption's secret.
std::vector<EncryptPublicKeys> public_keys_list_;
};
} // namespace worker
} // namespace fl

View File

@ -890,6 +890,11 @@ bool StartServerAction(const ResourcePtr &res) {
rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
}
if (encrypt_type == ps::kStablePWEncryptType) {
MS_LOG(INFO) << "Add stable secure aggregation rounds.";
rounds_config.push_back({"exchangeKeys", true, cipher_time_window, true, exchange_keys_threshold});
rounds_config.push_back({"getKeys", true, cipher_time_window, true, get_keys_threshold});
}
#endif
fl::server::CipherConfig cipher_config = {
share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,

View File

@ -56,7 +56,9 @@ enum class TcpUserCommand {
kNewInstance,
kQueryInstance,
kEnableFLS,
kDisableFLS
kDisableFLS,
kExchangeKeys,
kGetKeys
};
// CommunicatorBase is used to receive request and send response for server.

View File

@ -53,6 +53,8 @@ const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
{TcpUserCommand::kProceedToNextIter, "proceedToNextIter"},
{TcpUserCommand::kEndLastIter, "endLastIter"},
{TcpUserCommand::kStartFLJob, "startFLJob"},
{TcpUserCommand::kExchangeKeys, "exchangeKeys"},
{TcpUserCommand::kGetKeys, "getKeys"},
{TcpUserCommand::kUpdateModel, "updateModel"},
{TcpUserCommand::kGetModel, "getModel"},
{TcpUserCommand::kPushMetrics, "pushMetrics"},

View File

@ -199,9 +199,10 @@ void PSContext::set_server_mode(const std::string &server_mode) {
const std::string &PSContext::server_mode() const { return server_mode_; }
void PSContext::set_encrypt_type(const std::string &encrypt_type) {
if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) {
if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType &&
encrypt_type != kStablePWEncryptType) {
MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
<< kDPEncryptType << " or " << kPWEncryptType;
<< kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType;
return;
}
encrypt_type_ = encrypt_type;

View File

@ -37,6 +37,7 @@ constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
constexpr char kDPEncryptType[] = "DP_ENCRYPT";
constexpr char kPWEncryptType[] = "PW_ENCRYPT";
constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT";
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
// Use binary data to represent federated learning server's context so that we can judge which round resets the

View File

@ -985,11 +985,12 @@ def set_fl_context(**kwargs):
client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01.
dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is
suggested to be 0.5~2. Default: 1.0.
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT' or
'PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied for clients and the privacy
protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If
'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' model from stealing.
Default: 'NOT_ENCRYPT'.
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT',
'PW_ENCRYPT' or 'STABLE_PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied
for clients and the privacy protection effect would be determined by dp_eps, dp_delta and dp_norm_clip
as described above. If 'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients'
model from stealing in cross-device scenario. If 'STABLE_PW_ENCRYPT', pairwise secure aggregation would
be applied to protect clients' model from stealing in cross-silo scenario. Default: 'NOT_ENCRYPT'.
config_file_path (string): Configuration file path used by recovery. Default: ''.
scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.

View File

@ -92,7 +92,7 @@ from ._quant_ops import *
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
ConfusionMatrix, PopulationCount, UpdateState, Load,
CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight,
PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc)
PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc, ExchangeKeys, GetKeys)
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col,

View File

@ -885,6 +885,40 @@ class GetModel(PrimitiveWithInfer):
return mstype.float32
class ExchangeKeys(PrimitiveWithInfer):
"""
Exchange pairwise public keys for federated learning worker.
"""
@prim_attr_register
def __init__(self):
self.add_prim_attr("primitive_target", "CPU")
self.add_prim_attr('side_effect_mem', True)
self.init_prim_io_names(inputs=[], outputs=["result"])
def infer_shape(self):
return [1]
def infer_dtype(self):
return mstype.float32
class GetKeys(PrimitiveWithInfer):
"""
Get pairwise public keys for federated learning worker.
"""
@prim_attr_register
def __init__(self):
self.add_prim_attr("primitive_target", "CPU")
self.add_prim_attr('side_effect_mem', True)
self.init_prim_io_names(inputs=[], outputs=["result"])
def infer_shape(self):
return [1]
def infer_dtype(self):
return mstype.float32
class identity(Primitive):
"""
Makes a identify primitive, used for pynative mode.

View File

@ -31,6 +31,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--local_worker_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
args, _ = parser.parse_known_args()
device_target = args.device_target
@ -47,6 +48,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
local_worker_num = args.local_worker_num
config_file_path = args.config_file_path
dataset_path = args.dataset_path
encrypt_type = args.encrypt_type
if local_worker_num == -1:
local_worker_num = worker_num
@ -73,6 +75,7 @@ for i in range(local_worker_num):
cmd_worker += " --client_learning_rate=" + str(client_learning_rate)
cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration)
cmd_worker += " --dataset_path=" + str(dataset_path)
cmd_worker += " --encrypt_type=" + str(encrypt_type)
cmd_worker += " --user_id=" + str(i)
cmd_worker += " > worker.log 2>&1 &"

View File

@ -58,6 +58,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT")
parser.add_argument("--cipher_time_window", type=int, default=300000)
parser.add_argument("--dataset_path", type=str, default="")
# The user_id is used to set each worker's dataset path.
parser.add_argument("--user_id", type=str, default="0")
@ -87,6 +88,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
encrypt_type = args.encrypt_type
cipher_time_window = args.cipher_time_window
dataset_path = args.dataset_path
user_id = args.user_id
@ -111,7 +113,8 @@ ctx = {
"worker_step_num_per_iteration": worker_step_num_per_iteration,
"scheduler_manage_port": scheduler_manage_port,
"config_file_path": config_file_path,
"encrypt_type": encrypt_type
"encrypt_type": encrypt_type,
"cipher_time_window": cipher_time_window,
}
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
@ -285,6 +288,24 @@ class StartFLJob(nn.Cell):
return self.start_fl_job()
class ExchangeKeys(nn.Cell):
def __init__(self):
super(ExchangeKeys, self).__init__()
self.exchange_keys = P.ExchangeKeys()
def construct(self):
return self.exchange_keys()
class GetKeys(nn.Cell):
def __init__(self):
super(GetKeys, self).__init__()
self.get_keys = P.GetKeys()
def construct(self):
return self.get_keys()
class UpdateAndGetModel(nn.Cell):
def __init__(self, weights):
super(UpdateAndGetModel, self).__init__()
@ -328,6 +349,11 @@ def train():
if context.get_fl_context("ms_role") == "MS_WORKER":
start_fl_job = StartFLJob(dataset.get_dataset_size() * args.client_batch_size)
start_fl_job()
if encrypt_type == "STABLE_PW_ENCRYPT":
exchange_keys = ExchangeKeys()
exchange_keys()
get_keys = GetKeys()
get_keys()
for _ in range(epoch):
print("step is ", epoch, flush=True)