diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt index 1438c943283..d8c4167993b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt +++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt @@ -60,6 +60,8 @@ if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/start_fl_job_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/update_model_kernel.cc") list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/push_metrics_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/get_keys_kernel.cc") + list(REMOVE_ITEM CPU_SRC_LIST "cpu/fl/exchange_keys_kernel.cc") endif() if(ENABLE_SECURITY) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/exchange_keys_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/exchange_keys_kernel.cc new file mode 100644 index 00000000000..1df849f2388 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/exchange_keys_kernel.cc @@ -0,0 +1,168 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/fl/exchange_keys_kernel.h" + +namespace mindspore { +namespace kernel { +constexpr int iv_vec_len = 16; +constexpr int salt_len = 32; + +bool ExchangeKeysKernel::Launch(const std::vector &inputs, const std::vector &, + const std::vector &) { + MS_LOG(INFO) << "Launching client ExchangeKeysKernel"; + if (!BuildExchangeKeysReq(fbb_)) { + MS_LOG(EXCEPTION) << "Building request for ExchangeKeys failed."; + return false; + } + + std::shared_ptr> exchange_keys_rsp_msg = nullptr; + if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(), fbb_->GetSize(), + ps::core::TcpUserCommand::kExchangeKeys, + &exchange_keys_rsp_msg)) { + MS_LOG(EXCEPTION) << "Sending request for ExchangeKeys to server " << target_server_rank_ << " failed."; + return false; + } + if (exchange_keys_rsp_msg == nullptr) { + MS_LOG(EXCEPTION) << "Received message pointer is nullptr."; + return false; + } + flatbuffers::Verifier verifier(exchange_keys_rsp_msg->data(), exchange_keys_rsp_msg->size()); + if (!verifier.VerifyBuffer()) { + MS_LOG(EXCEPTION) << "The schema of ResponseExchangeKeys is invalid."; + return false; + } + + const schema::ResponseExchangeKeys *exchange_keys_rsp = + flatbuffers::GetRoot(exchange_keys_rsp_msg->data()); + MS_EXCEPTION_IF_NULL(exchange_keys_rsp); + auto response_code = exchange_keys_rsp->retcode(); + if ((response_code != schema::ResponseCode_SUCCEED) && (response_code != schema::ResponseCode_OutOfTime)) { + MS_LOG(EXCEPTION) << "Launching exchange keys job for worker failed. Reason: " << exchange_keys_rsp->reason(); + } + + MS_LOG(INFO) << "Exchange keys successfully."; + return true; +} + +void ExchangeKeysKernel::Init(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + fl_id_ = fl::worker::FLWorker::GetInstance().fl_id(); + server_num_ = fl::worker::FLWorker::GetInstance().server_num(); + rank_id_ = fl::worker::FLWorker::GetInstance().rank_id(); + if (rank_id_ == UINT32_MAX) { + MS_LOG(EXCEPTION) << "Federated worker is not initialized yet."; + return; + } + + if (server_num_ <= 0) { + MS_LOG(EXCEPTION) << "Server number should be larger than 0, but got: " << server_num_; + return; + } + target_server_rank_ = rank_id_ % server_num_; + + MS_LOG(INFO) << "Initializing ExchangeKeys kernel" + << ", fl_id: " << fl_id_ << ". Request will be sent to server " << target_server_rank_; + + fbb_ = std::make_shared(); + MS_EXCEPTION_IF_NULL(fbb_); + input_size_list_.push_back(sizeof(int)); + output_size_list_.push_back(sizeof(float)); + MS_LOG(INFO) << "Initialize ExchangeKeys kernel successfully."; +} + +void ExchangeKeysKernel::InitKernel(const CNodePtr &kernel_node) { return; } + +bool ExchangeKeysKernel::BuildExchangeKeysReq(const std::shared_ptr &fbb) { + MS_EXCEPTION_IF_NULL(fbb); + // generate initialization vector value used for generate pairwise noise + std::vector pw_iv_(iv_vec_len); + std::vector pw_salt_(salt_len); + auto ret = RAND_bytes(pw_iv_.data(), iv_vec_len); + if (ret != 1) { + MS_LOG(ERROR) << "RAND_bytes error, failed to init pw_iv_."; + return false; + } + // generate salt value used for generate pairwise noise + ret = RAND_bytes(pw_salt_.data(), salt_len); + if (ret != 1) { + MS_LOG(ERROR) << "RAND_bytes error, failed to init pw_salt_."; + return false; + } + + // save pw_salt and pw_iv at local + fl::worker::FLWorker::GetInstance().set_pw_salt(pw_salt_); + fl::worker::FLWorker::GetInstance().set_pw_iv(pw_iv_); + + // get public key bytes + std::vector pubkey_bytes = GetPubicKeyBytes(); + if (pubkey_bytes.size() == 0) { + MS_LOG(EXCEPTION) << "Get public key failed."; + return false; + } + + // build data which will be send to server + int iter = fl::worker::FLWorker::GetInstance().fl_iteration_num(); + auto fbs_fl_id = fbb->CreateString(fl_id_); + auto fbs_public_key = fbb->CreateVector(pubkey_bytes.data(), pubkey_bytes.size()); + auto fbs_pw_iv = fbb->CreateVector(pw_iv_.data(), iv_vec_len); + auto fbs_pw_salt = fbb->CreateVector(pw_salt_.data(), salt_len); + schema::RequestExchangeKeysBuilder req_exchange_key_builder(*(fbb.get())); + req_exchange_key_builder.add_fl_id(fbs_fl_id); + req_exchange_key_builder.add_s_pk(fbs_public_key); + req_exchange_key_builder.add_iteration(iter); + req_exchange_key_builder.add_pw_iv(fbs_pw_iv); + req_exchange_key_builder.add_pw_salt(fbs_pw_salt); + auto req_fl_job = req_exchange_key_builder.Finish(); + fbb->Finish(req_fl_job); + MS_LOG(INFO) << "BuildExchangeKeysReq successfully."; + return true; +} + +std::vector ExchangeKeysKernel::GetPubicKeyBytes() { + // generate private key of secret + armour::PrivateKey *sPriKeyPtr = armour::KeyAgreement::GeneratePrivKey(); + fl::worker::FLWorker::GetInstance().set_secret_pk(sPriKeyPtr); + + // get public bytes length + size_t pubLen; + uint8_t *secret_pubkey_ptr = NULL; + auto ret = sPriKeyPtr->GetPublicBytes(&pubLen, secret_pubkey_ptr); + if (ret != 0 || pubLen == 0) { + MS_LOG(ERROR) << "GetPublicBytes error, failed to get public_key bytes length."; + return {}; + } + // pubLen has been updated, now get public_key bytes + secret_pubkey_ptr = reinterpret_cast(malloc(pubLen)); + ret = sPriKeyPtr->GetPublicBytes(&pubLen, secret_pubkey_ptr); + if (ret != 0) { + free(secret_pubkey_ptr); + MS_LOG(ERROR) << "GetPublicBytes error, failed to get public_key bytes."; + return {}; + } + + // transform key buffer to uint8_t vector + std::vector pubkey_bytes(pubLen); + for (int i = 0; i < SizeToInt(pubLen); i++) { + pubkey_bytes[i] = secret_pubkey_ptr[i]; + } + free(secret_pubkey_ptr); + return pubkey_bytes; +} + +MS_REG_CPU_KERNEL(ExchangeKeys, KernelAttr().AddOutputAttr(kNumberTypeFloat32), ExchangeKeysKernel); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/exchange_keys_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/exchange_keys_kernel.h new file mode 100644 index 00000000000..be47b447f86 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/exchange_keys_kernel.h @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_EXCHANGE_KEYS_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_EXCHANGE_KEYS_H_ + +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "fl/worker/fl_worker.h" +#include "fl/armour/secure_protocol/key_agreement.h" + +namespace mindspore { +namespace kernel { +class ExchangeKeysKernel : public CPUKernel { + public: + ExchangeKeysKernel() = default; + ~ExchangeKeysKernel() override = default; + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &) override; + + void Init(const CNodePtr &kernel_node) override; + + void InitKernel(const CNodePtr &kernel_node) override; + + private: + bool BuildExchangeKeysReq(const std::shared_ptr &fbb); + std::vector GetPubicKeyBytes(); + + uint32_t rank_id_; + uint32_t server_num_; + uint32_t target_server_rank_; + std::string fl_id_; + std::shared_ptr fbb_; + armour::PrivateKey *secret_prikey_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_EXCHANGE_KEYS_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/get_keys_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/get_keys_kernel.cc new file mode 100644 index 00000000000..6fa72b8eff9 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/get_keys_kernel.cc @@ -0,0 +1,141 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/fl/get_keys_kernel.h" + +namespace mindspore { +namespace kernel { +bool GetKeysKernel::Launch(const std::vector &inputs, const std::vector &, + const std::vector &) { + MS_LOG(INFO) << "Launching client GetKeysKernel"; + BuildGetKeysReq(fbb_); + + std::shared_ptr> get_keys_rsp_msg = nullptr; + if (!fl::worker::FLWorker::GetInstance().SendToServer(target_server_rank_, fbb_->GetBufferPointer(), fbb_->GetSize(), + ps::core::TcpUserCommand::kGetKeys, &get_keys_rsp_msg)) { + MS_LOG(EXCEPTION) << "Sending request for GetKeys to server " << target_server_rank_ << " failed."; + return false; + } + if (get_keys_rsp_msg == nullptr) { + MS_LOG(EXCEPTION) << "Received message pointer is nullptr."; + return false; + } + flatbuffers::Verifier verifier(get_keys_rsp_msg->data(), get_keys_rsp_msg->size()); + if (!verifier.VerifyBuffer()) { + MS_LOG(EXCEPTION) << "The schema of ResponseGetKeys is invalid."; + return false; + } + + const schema::ReturnExchangeKeys *get_keys_rsp = + flatbuffers::GetRoot(get_keys_rsp_msg->data()); + MS_EXCEPTION_IF_NULL(get_keys_rsp); + auto response_code = get_keys_rsp->retcode(); + if ((response_code != schema::ResponseCode_SUCCEED) && (response_code != schema::ResponseCode_OutOfTime)) { + MS_LOG(EXCEPTION) << "Launching get keys job for worker failed. response_code: " << response_code; + } + + bool save_keys_succeed = SavePublicKeyList(get_keys_rsp->remote_publickeys()); + if (!save_keys_succeed) { + MS_LOG(EXCEPTION) << "Save received remote keys failed."; + return false; + } + + MS_LOG(INFO) << "Get keys successfully."; + return true; +} + +void GetKeysKernel::Init(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + fl_id_ = fl::worker::FLWorker::GetInstance().fl_id(); + server_num_ = fl::worker::FLWorker::GetInstance().server_num(); + rank_id_ = fl::worker::FLWorker::GetInstance().rank_id(); + if (rank_id_ == UINT32_MAX) { + MS_LOG(EXCEPTION) << "Federated worker is not initialized yet."; + return; + } + if (server_num_ <= 0) { + MS_LOG(EXCEPTION) << "Server number should be larger than 0, but got: " << server_num_; + return; + } + target_server_rank_ = rank_id_ % server_num_; + + MS_LOG(INFO) << "Initializing GetKeys kernel" + << ", fl_id: " << fl_id_ << ". Request will be sent to server " << target_server_rank_; + + fbb_ = std::make_shared(); + MS_EXCEPTION_IF_NULL(fbb_); + input_size_list_.push_back(sizeof(int)); + output_size_list_.push_back(sizeof(float)); + MS_LOG(INFO) << "Initialize GetKeys kernel successfully."; +} + +void GetKeysKernel::InitKernel(const CNodePtr &kernel_node) { return; } + +void GetKeysKernel::BuildGetKeysReq(const std::shared_ptr &fbb) { + MS_EXCEPTION_IF_NULL(fbb); + int iter = fl::worker::FLWorker::GetInstance().fl_iteration_num(); + auto fbs_fl_id = fbb->CreateString(fl_id_); + schema::GetExchangeKeysBuilder get_keys_builder(*(fbb.get())); + get_keys_builder.add_fl_id(fbs_fl_id); + get_keys_builder.add_iteration(iter); + auto req_fl_job = get_keys_builder.Finish(); + fbb->Finish(req_fl_job); + MS_LOG(INFO) << "BuildGetKeysReq successfully."; +} + +bool GetKeysKernel::SavePublicKeyList(auto remote_public_key) { + if (remote_public_key == nullptr) { + MS_LOG(EXCEPTION) << "Input remote_pubic_key is nullptr."; + } + + int client_num = remote_public_key->size(); + if (client_num <= 0) { + MS_LOG(EXCEPTION) << "Received client keys length is <= 0, please check it!"; + return false; + } + + // save client keys list + std::vector saved_remote_public_keys; + for (auto iter = remote_public_key->begin(); iter != remote_public_key->end(); ++iter) { + std::string fl_id = iter->fl_id()->str(); + auto fbs_spk = iter->s_pk(); + auto fbs_pw_iv = iter->pw_iv(); + auto fbs_pw_salt = iter->pw_salt(); + if (fbs_spk == nullptr || fbs_pw_iv == nullptr || fbs_pw_salt == nullptr) { + MS_LOG(WARNING) << "public key, pw_iv or pw_salt in remote_publickeys is nullptr."; + } else { + std::vector spk_vector; + std::vector pw_iv_vector; + std::vector pw_salt_vector; + spk_vector.assign(fbs_spk->begin(), fbs_spk->end()); + pw_iv_vector.assign(fbs_pw_iv->begin(), fbs_pw_iv->end()); + pw_salt_vector.assign(fbs_pw_salt->begin(), fbs_pw_salt->end()); + EncryptPublicKeys public_keys_i; + public_keys_i.flID = fl_id; + public_keys_i.publicKey = spk_vector; + public_keys_i.pwIV = pw_iv_vector; + public_keys_i.pwSalt = pw_salt_vector; + saved_remote_public_keys.push_back(public_keys_i); + MS_LOG(INFO) << "Add public keys of client:" << fl_id << " successfully."; + } + } + fl::worker::FLWorker::GetInstance().set_public_keys_list(saved_remote_public_keys); + return true; +} + +MS_REG_CPU_KERNEL(GetKeys, KernelAttr().AddOutputAttr(kNumberTypeFloat32), GetKeysKernel); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/get_keys_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/get_keys_kernel.h new file mode 100644 index 00000000000..2d0cbc63e8c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/get_keys_kernel.h @@ -0,0 +1,55 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_KEYS_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_KEYS_H_ + +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "fl/worker/fl_worker.h" +#include "fl/armour/secure_protocol/key_agreement.h" + +namespace mindspore { +namespace kernel { +class GetKeysKernel : public CPUKernel { + public: + GetKeysKernel() = default; + ~GetKeysKernel() override = default; + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &) override; + + void Init(const CNodePtr &kernel_node) override; + + void InitKernel(const CNodePtr &kernel_node) override; + + private: + void BuildGetKeysReq(const std::shared_ptr &fbb); + bool SavePublicKeyList(auto remote_public_key); + + uint32_t rank_id_; + uint32_t server_num_; + uint32_t target_server_rank_; + std::string fl_id_; + std::shared_ptr fbb_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_FL_GET_KEYS_H_ diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc index 163090e5277..aedac9b9900 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc @@ -81,6 +81,11 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size } MS_LOG(INFO) << " CipherInit::Init Success"; } + + if (param.encrypt_type == mindspore::ps::kStablePWEncryptType) { + cipher_meta_storage_.RegisterStablePWClass(); + MS_LOG(INFO) << "Register metadata for StablePWEncrypt is finished."; + } return true; } diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc index f38e905f6c2..18885d75da3 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_keys.cc @@ -34,7 +34,12 @@ bool CipherKeys::GetKeys(const size_t cur_iterator, const std::string &next_req_ } // get clientlist from memory server. std::map>> client_public_keys; - cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys); + std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); + if (encrypt_type == ps::kPWEncryptType) { + cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys); + } else { + cipher_init_->cipher_meta_storage_.GetStableClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys); + } size_t cur_exchange_clients_num = client_public_keys.size(); std::string fl_id = get_exchange_keys_req->fl_id()->str(); @@ -97,7 +102,13 @@ bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next std::map>> client_public_keys; std::vector client_list; cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxExChangeKeysClientList, &client_list); - cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys); + + std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); + if (encrypt_type == ps::kPWEncryptType) { + cipher_init_->cipher_meta_storage_.GetClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys); + } else { + cipher_init_->cipher_meta_storage_.GetStableClientKeysFromServer(fl::server::kCtxClientsKeys, &client_public_keys); + } // step2: process new item data. and update new item data to memory server. size_t cur_clients_num = client_list.size(); @@ -119,8 +130,15 @@ bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next return false; } - bool retcode_key = - cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req); + bool retcode_key; + if (encrypt_type == ps::kPWEncryptType) { + retcode_key = + cipher_init_->cipher_meta_storage_.UpdateClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req); + } else { + retcode_key = + cipher_init_->cipher_meta_storage_.UpdateStableClientKeyToServer(fl::server::kCtxClientsKeys, exchange_keys_req); + } + bool retcode_client = cipher_init_->cipher_meta_storage_.UpdateClientToServer(fl::server::kCtxExChangeKeysClientList, fl_id); if (retcode_key && retcode_client) { @@ -134,7 +152,7 @@ bool CipherKeys::ExchangeKeys(const size_t cur_iterator, const std::string &next cur_iterator); return false; } -} +} // namespace armour void CipherKeys::BuildExchangeKeysRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, @@ -164,6 +182,7 @@ void CipherKeys::BuildGetKeysRsp(const std::shared_ptr &f fbb->Finish(rsp_get_keys); return; } + const fl::PBMetadata &clients_keys_pb_out = fl::server::DistributedMetadataStore::GetInstance().GetMetadata(fl::server::kCtxClientsKeys); const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys(); @@ -172,17 +191,27 @@ void CipherKeys::BuildGetKeysRsp(const std::shared_ptr &f std::string fl_id = iter->first; fl::KeysPb keys_pb = iter->second; auto fbs_fl_id = fbb->CreateString(fl_id); - std::vector cpk(keys_pb.key(0).begin(), keys_pb.key(0).end()); - std::vector spk(keys_pb.key(1).begin(), keys_pb.key(1).end()); - auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size()); - auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size()); std::vector pw_iv(keys_pb.pw_iv().begin(), keys_pb.pw_iv().end()); auto fbs_pw_iv = fbb->CreateVector(pw_iv.data(), pw_iv.size()); std::vector pw_salt(keys_pb.pw_salt().begin(), keys_pb.pw_salt().end()); auto fbs_pw_salt = fbb->CreateVector(pw_salt.data(), pw_salt.size()); - auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt); - public_keys_list.push_back(cur_public_key); + + std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); + if (encrypt_type == ps::kPWEncryptType) { + std::vector cpk(keys_pb.key(0).begin(), keys_pb.key(0).end()); + std::vector spk(keys_pb.key(1).begin(), keys_pb.key(1).end()); + auto fbs_c_pk = fbb->CreateVector(cpk.data(), cpk.size()); + auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size()); + auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, fbs_c_pk, fbs_s_pk, fbs_pw_iv, fbs_pw_salt); + public_keys_list.push_back(cur_public_key); + } else { + std::vector spk(keys_pb.key(0).begin(), keys_pb.key(0).end()); + auto fbs_s_pk = fbb->CreateVector(spk.data(), spk.size()); + auto cur_public_key = schema::CreateClientPublicKeys(*fbb, fbs_fl_id, 0, fbs_s_pk, fbs_pw_iv, fbs_pw_salt); + public_keys_list.push_back(cur_public_key); + } } + auto remote_publickeys = fbb->CreateVector(public_keys_list); auto fbs_next_req_time = fbb->CreateString(next_req_time); schema::ReturnExchangeKeysBuilder rsp_builder(*(fbb.get())); diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc index f6857937482..8b1b4f82d7a 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.cc @@ -81,6 +81,26 @@ void CipherMetaStorage::GetClientKeysFromServer( } } +void CipherMetaStorage::GetStableClientKeysFromServer( + const char *list_name, std::map>> *clients_keys_list) { + if (clients_keys_list == nullptr) { + MS_LOG(ERROR) << "Input clients_keys_list is nullptr"; + return; + } + const fl::PBMetadata &clients_keys_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(list_name); + const fl::ClientKeys &clients_keys_pb = clients_keys_pb_out.client_keys(); + + for (auto iter = clients_keys_pb.client_keys().begin(); iter != clients_keys_pb.client_keys().end(); ++iter) { + std::string fl_id = iter->first; + fl::KeysPb keys_pb = iter->second; + std::vector spk(keys_pb.key(0).begin(), keys_pb.key(0).end()); + std::vector> cur_keys; + cur_keys.push_back(spk); + (void)clients_keys_list->emplace(std::pair>>(fl_id, cur_keys)); + } +} + void CipherMetaStorage::GetClientIVsFromServer( const char *list_name, std::map>> *clients_ivs_list) { if (clients_ivs_list == nullptr) { @@ -171,26 +191,6 @@ void CipherMetaStorage::RegisterPrime(const char *list_name, const std::string & (void)sleep(time); } -bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const std::string &fl_id, - const std::vector> &cur_public_key) { - const size_t correct_size = 2; - if (cur_public_key.size() < correct_size) { - MS_LOG(ERROR) << "cur_public_key's size must is 2. actual size is " << cur_public_key.size(); - return false; - } - // update new item to memory server. - fl::KeysPb keys; - keys.add_key()->assign(cur_public_key[0].begin(), cur_public_key[0].end()); - keys.add_key()->assign(cur_public_key[1].begin(), cur_public_key[1].end()); - fl::PairClientKeys pair_client_keys_pb; - pair_client_keys_pb.set_fl_id(fl_id); - pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys); - fl::PBMetadata client_and_keys_pb; - client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb); - bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb); - return retcode; -} - bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req) { std::string fl_id = exchange_keys_req->fl_id()->str(); @@ -257,6 +257,55 @@ bool CipherMetaStorage::UpdateClientKeyToServer(const char *list_name, return retcode; } +bool CipherMetaStorage::UpdateStableClientKeyToServer(const char *list_name, + const schema::RequestExchangeKeys *exchange_keys_req) { + std::string fl_id = exchange_keys_req->fl_id()->str(); + auto fbs_spk = exchange_keys_req->s_pk(); + if (fbs_spk == nullptr) { + MS_LOG(ERROR) << "Public key from exchange_keys_req is null"; + return false; + } + + size_t spk_len = fbs_spk->size(); + + // transform fbs_spk to a vector: public_key + std::vector spk(spk_len); + bool ret_create_code_spk = CreateArray(&spk, *fbs_spk); + if (!ret_create_code_spk) { + MS_LOG(ERROR) << "Create array for public keys failed"; + return false; + } + + auto fbs_pw_iv = exchange_keys_req->pw_iv(); + std::vector pw_iv; + if (fbs_pw_iv == nullptr) { + MS_LOG(WARNING) << "pw_iv in exchange_keys_req is nullptr"; + } else { + pw_iv.assign(fbs_pw_iv->begin(), fbs_pw_iv->end()); + } + + auto fbs_pw_salt = exchange_keys_req->pw_salt(); + std::vector pw_salt; + if (fbs_pw_salt == nullptr) { + MS_LOG(WARNING) << "pw_salt in exchange_keys_req is nullptr"; + } else { + pw_salt.assign(fbs_pw_salt->begin(), fbs_pw_salt->end()); + } + + // update new item to memory server. + fl::KeysPb keys; + keys.add_key()->assign(spk.begin(), spk.end()); + keys.set_pw_iv(pw_iv.data(), pw_iv.size()); + keys.set_pw_salt(pw_salt.data(), pw_salt.size()); + fl::PairClientKeys pair_client_keys_pb; + pair_client_keys_pb.set_fl_id(fl_id); + pair_client_keys_pb.mutable_client_keys()->MergeFrom(keys); + fl::PBMetadata client_and_keys_pb; + client_and_keys_pb.mutable_pair_client_keys()->MergeFrom(pair_client_keys_pb); + bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(list_name, client_and_keys_pb); + return retcode; +} + bool CipherMetaStorage::UpdateClientNoiseToServer(const char *list_name, const std::vector &cur_public_noise) { // update new item to memory server. fl::OneClientNoises noises_pb; @@ -325,5 +374,16 @@ void CipherMetaStorage::RegisterClass() { fl::PBMetadata client_noises; fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientNoises, client_noises); } + +void CipherMetaStorage::RegisterStablePWClass() { + fl::PBMetadata exchange_keys_client_list; + fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxExChangeKeysClientList, + exchange_keys_client_list); + fl::PBMetadata get_keys_client_list; + fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxGetKeysClientList, + get_keys_client_list); + fl::PBMetadata clients_keys; + fl::server::DistributedMetadataStore::GetInstance().RegisterMetadata(fl::server::kCtxClientsKeys, clients_keys); +} } // namespace armour } // namespace mindspore diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h index c9c4130fbad..7d1811b16d9 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_meta_storage.h @@ -74,6 +74,7 @@ class CipherMetaStorage { public: // Register the shared value involved in the security aggregation. void RegisterClass(); + void RegisterStablePWClass(); // Register Prime. void RegisterPrime(const char *list_name, const std::string &prime); @@ -87,17 +88,19 @@ class CipherMetaStorage { // Get client keys from shared server. void GetClientKeysFromServer(const char *list_name, std::map>> *clients_keys_list); + // Get stable secure aggregation's client key from shared server. + void GetStableClientKeysFromServer(const char *list_name, + std::map>> *clients_keys_list); void GetClientIVsFromServer(const char *list_name, std::map>> *clients_ivs_list); // Get client noises from shared server. bool GetClientNoisesFromServer(const char *list_name, std::vector *cur_public_noise); // Update client fl_id to shared server. bool UpdateClientToServer(const char *list_name, const std::string &fl_id); - // Update client key to shared server. - bool UpdateClientKeyToServer(const char *list_name, const std::string &fl_id, - const std::vector> &cur_public_key); // Update client key with signature to shared server. bool UpdateClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req); + // Update stable secure aggregation's client key to shared server. + bool UpdateStableClientKeyToServer(const char *list_name, const schema::RequestExchangeKeys *exchange_keys_req); // Update client noise to shared server. bool UpdateClientNoiseToServer(const char *list_name, const std::vector &cur_public_noise); // Update client share to shared server. diff --git a/mindspore/ccsrc/fl/armour/secure_protocol/key_agreement.cc b/mindspore/ccsrc/fl/armour/secure_protocol/key_agreement.cc index 8cbacfc0d78..4180071df42 100644 --- a/mindspore/ccsrc/fl/armour/secure_protocol/key_agreement.cc +++ b/mindspore/ccsrc/fl/armour/secure_protocol/key_agreement.cc @@ -66,7 +66,7 @@ int PrivateKey::GetPrivateBytes(size_t *len, uint8_t *privKeyBytes) const { } int PrivateKey::GetPublicBytes(size_t *len, uint8_t *pubKeyBytes) const { - if (pubKeyBytes == nullptr || len == nullptr || evpPrivKey == nullptr) { + if (evpPrivKey == nullptr) { MS_LOG(ERROR) << "input pubKeyBytes invalid."; return -1; } diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc index 8c52031ebde..b0b80c1331a 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc @@ -125,7 +125,10 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; return {}; } - + if (metadata_.count(name) == 0) { + MS_LOG(ERROR) << "The metadata of " << name << " is not registered."; + return {}; + } uint32_t stored_rank = router_->Find(name); MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank; if (local_rank_ == stored_rank) { @@ -257,42 +260,15 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P } else if (meta.has_prime()) { metadata_[name] = meta; } else if (meta.has_pair_client_keys()) { - auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys(); - auto &fl_id = meta.pair_client_keys().fl_id(); - auto &client_keys = meta.pair_client_keys().client_keys(); - // Check whether the new item already exists. - bool add_flag = true; - for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) { - if (fl_id == iter->first) { - add_flag = false; - MS_LOG(ERROR) << "Leader server updating value for " << name - << " failed: The Protobuffer of this value already exists."; - break; - } - } - if (add_flag) { - client_keys_map[fl_id] = client_keys; - } else { + bool keys_update_succeed = UpdatePairClientKeys(name, meta); + if (!keys_update_succeed) { + MS_LOG(ERROR) << "Update pair_client_keys failed."; return false; } } else if (meta.has_pair_client_shares()) { - auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares(); - auto &fl_id = meta.pair_client_shares().fl_id(); - auto &client_shares = meta.pair_client_shares().client_shares(); - // google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter; - // Check whether the new item already exists. - bool add_flag = true; - for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) { - if (fl_id == iter->first) { - add_flag = false; - MS_LOG(ERROR) << "Leader server updating value for " << name - << " failed: The Protobuffer of this value already exists."; - break; - } - } - if (add_flag) { - client_shares_map[fl_id] = client_shares; - } else { + bool shares_update_succeed = UpdatePairClientShares(name, meta); + if (!shares_update_succeed) { + MS_LOG(ERROR) << "Update pair_client_shares failed."; return false; } } else if (meta.has_one_client_noises()) { @@ -310,6 +286,51 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P } return true; } + +bool DistributedMetadataStore::UpdatePairClientKeys(const std::string &name, const PBMetadata &meta) { + auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys(); + auto &fl_id = meta.pair_client_keys().fl_id(); + auto &client_keys = meta.pair_client_keys().client_keys(); + // Check whether the new item already exists. + bool add_flag = true; + for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); ++iter) { + if (fl_id == iter->first) { + add_flag = false; + MS_LOG(ERROR) << "Leader server updating value for " << name + << " failed: The Protobuffer of this value already exists."; + break; + } + } + if (add_flag) { + client_keys_map[fl_id] = client_keys; + return true; + } else { + return false; + } +} + +bool DistributedMetadataStore::UpdatePairClientShares(const std::string &name, const PBMetadata &meta) { + auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares(); + auto &fl_id = meta.pair_client_shares().fl_id(); + auto &client_shares = meta.pair_client_shares().client_shares(); + // google::protobuf::Map< std::string, mindspore::fl::ps::core::SharesPb >::const_iterator iter; + // Check whether the new item already exists. + bool add_flag = true; + for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); ++iter) { + if (fl_id == iter->first) { + add_flag = false; + MS_LOG(ERROR) << "Leader server updating value for " << name + << " failed: The Protobuffer of this value already exists."; + break; + } + } + if (add_flag) { + client_shares_map[fl_id] = client_shares; + return true; + } else { + return false; + } +} } // namespace server } // namespace fl } // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.h b/mindspore/ccsrc/fl/server/distributed_metadata_store.h index d74b2160a3c..743ecf33913 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.h +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.h @@ -88,6 +88,12 @@ class DistributedMetadataStore { // Do updating metadata in the server where the metadata for the name is stored. bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta); + // Update client keys stored in server + bool UpdatePairClientKeys(const std::string &name, const PBMetadata &meta); + + // Update client shares stored in server + bool UpdatePairClientShares(const std::string &name, const PBMetadata &meta); + // Members for the communication between servers. std::shared_ptr server_node_; std::shared_ptr communicator_; diff --git a/mindspore/ccsrc/fl/worker/fl_worker.cc b/mindspore/ccsrc/fl/worker/fl_worker.cc index 836fc264bca..fd821e3b10e 100644 --- a/mindspore/ccsrc/fl/worker/fl_worker.cc +++ b/mindspore/ccsrc/fl/worker/fl_worker.cc @@ -19,6 +19,7 @@ #include #include #include "fl/worker/fl_worker.h" +#include "fl/armour/secure_protocol/key_agreement.h" #include "utils/ms_exception.h" namespace mindspore { @@ -173,8 +174,26 @@ uint64_t FLWorker::fl_iteration_num() const { return iteration_num_.load(); } void FLWorker::set_data_size(int data_size) { data_size_ = data_size; } +void FLWorker::set_secret_pk(armour::PrivateKey *secret_pk) { secret_pk_ = secret_pk; } + +void FLWorker::set_pw_salt(std::vector pw_salt) { pw_salt_ = pw_salt; } + +void FLWorker::set_pw_iv(std::vector pw_iv) { pw_iv_ = pw_iv; } + +void FLWorker::set_public_keys_list(std::vector public_keys_list) { + public_keys_list_ = public_keys_list; +} + int FLWorker::data_size() const { return data_size_; } +armour::PrivateKey *FLWorker::secret_pk() const { return secret_pk_; } + +std::vector FLWorker::pw_salt() const { return pw_salt_; } + +std::vector FLWorker::pw_iv() const { return pw_iv_; } + +std::vector FLWorker::public_keys_list() const { return public_keys_list_; } + std::string FLWorker::fl_name() const { return ps::kServerModeFL; } std::string FLWorker::fl_id() const { return std::to_string(rank_id_); } diff --git a/mindspore/ccsrc/fl/worker/fl_worker.h b/mindspore/ccsrc/fl/worker/fl_worker.h index f886a81dd82..912b0b97b59 100644 --- a/mindspore/ccsrc/fl/worker/fl_worker.h +++ b/mindspore/ccsrc/fl/worker/fl_worker.h @@ -23,11 +23,19 @@ #include "proto/comm.pb.h" #include "schema/fl_job_generated.h" #include "schema/cipher_generated.h" +#include "fl/armour/secure_protocol/key_agreement.h" #include "ps/ps_context.h" #include "ps/core/worker_node.h" #include "ps/core/cluster_metadata.h" #include "ps/core/communicator/tcp_communicator.h" +struct EncryptPublicKeys { + std::string flID; + std::vector publicKey; + std::vector pwIV; + std::vector pwSalt; +}; + namespace mindspore { namespace fl { using FBBuilder = flatbuffers::FlatBufferBuilder; @@ -88,6 +96,18 @@ class FLWorker { void set_data_size(int data_size); int data_size() const; + void set_secret_pk(armour::PrivateKey *secret_pk); + armour::PrivateKey *secret_pk() const; + + void set_pw_salt(std::vector pw_salt); + std::vector pw_salt() const; + + void set_pw_iv(std::vector pw_iv); + std::vector pw_iv() const; + + void set_public_keys_list(std::vector public_keys_list); + std::vector public_keys_list() const; + std::string fl_name() const; std::string fl_id() const; @@ -105,7 +125,11 @@ class FLWorker { worker_step_num_per_iteration_(1), server_iteration_state_(IterationState::kCompleted), worker_iteration_state_(IterationState::kCompleted), - safemode_(false) {} + safemode_(false), + secret_pk_(nullptr), + pw_salt_({}), + pw_iv_({}), + public_keys_list_({}) {} ~FLWorker() = default; FLWorker(const FLWorker &) = delete; FLWorker &operator=(const FLWorker &) = delete; @@ -152,6 +176,18 @@ class FLWorker { // The flag that represents whether worker is in safemode, which is decided by both worker and server iteration state. std::atomic_bool safemode_; + + // The private key used for computing the pairwise encryption's secret. + armour::PrivateKey *secret_pk_; + + // The salt value used for generate pairwise noise. + std::vector pw_salt_; + + // The initialization vector value used for generate pairwise noise. + std::vector pw_iv_; + + // The public keys used for computing the pairwise encryption's secret. + std::vector public_keys_list_; }; } // namespace worker } // namespace fl diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 55d38453ede..84c59584ea9 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -890,6 +890,11 @@ bool StartServerAction(const ResourcePtr &res) { rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold}); rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold}); } + if (encrypt_type == ps::kStablePWEncryptType) { + MS_LOG(INFO) << "Add stable secure aggregation rounds."; + rounds_config.push_back({"exchangeKeys", true, cipher_time_window, true, exchange_keys_threshold}); + rounds_config.push_back({"getKeys", true, cipher_time_window, true, get_keys_threshold}); + } #endif fl::server::CipherConfig cipher_config = { share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold, diff --git a/mindspore/ccsrc/ps/core/communicator/communicator_base.h b/mindspore/ccsrc/ps/core/communicator/communicator_base.h index 2f435553741..c01aec8fbbc 100644 --- a/mindspore/ccsrc/ps/core/communicator/communicator_base.h +++ b/mindspore/ccsrc/ps/core/communicator/communicator_base.h @@ -56,7 +56,9 @@ enum class TcpUserCommand { kNewInstance, kQueryInstance, kEnableFLS, - kDisableFLS + kDisableFLS, + kExchangeKeys, + kGetKeys }; // CommunicatorBase is used to receive request and send response for server. diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h index cbd3e3ba3c7..61bacc42ab4 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h @@ -53,6 +53,8 @@ const std::unordered_map kUserCommandToMsgType = { {TcpUserCommand::kProceedToNextIter, "proceedToNextIter"}, {TcpUserCommand::kEndLastIter, "endLastIter"}, {TcpUserCommand::kStartFLJob, "startFLJob"}, + {TcpUserCommand::kExchangeKeys, "exchangeKeys"}, + {TcpUserCommand::kGetKeys, "getKeys"}, {TcpUserCommand::kUpdateModel, "updateModel"}, {TcpUserCommand::kGetModel, "getModel"}, {TcpUserCommand::kPushMetrics, "pushMetrics"}, diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 033af166db0..3068144d1c6 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -199,9 +199,10 @@ void PSContext::set_server_mode(const std::string &server_mode) { const std::string &PSContext::server_mode() const { return server_mode_; } void PSContext::set_encrypt_type(const std::string &encrypt_type) { - if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) { + if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType && + encrypt_type != kStablePWEncryptType) { MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or " - << kDPEncryptType << " or " << kPWEncryptType; + << kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType; return; } encrypt_type_ = encrypt_type; diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 4c255fbc7c6..60b2823cc6f 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -37,6 +37,7 @@ constexpr char kEnvRoleOfScheduler[] = "MS_SCHED"; constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS"; constexpr char kDPEncryptType[] = "DP_ENCRYPT"; constexpr char kPWEncryptType[] = "PW_ENCRYPT"; +constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT"; constexpr char kNotEncryptType[] = "NOT_ENCRYPT"; // Use binary data to represent federated learning server's context so that we can judge which round resets the diff --git a/mindspore/context.py b/mindspore/context.py index 2a451b26986..af0893bdeca 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -985,11 +985,12 @@ def set_fl_context(**kwargs): client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01. dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is suggested to be 0.5~2. Default: 1.0. - encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT' or - 'PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied for clients and the privacy - protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If - 'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' model from stealing. - Default: 'NOT_ENCRYPT'. + encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT', + 'PW_ENCRYPT' or 'STABLE_PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied + for clients and the privacy protection effect would be determined by dp_eps, dp_delta and dp_norm_clip + as described above. If 'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' + model from stealing in cross-device scenario. If 'STABLE_PW_ENCRYPT', pairwise secure aggregation would + be applied to protect clients' model from stealing in cross-silo scenario. Default: 'NOT_ENCRYPT'. config_file_path (string): Configuration file path used by recovery. Default: ''. scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202. enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true. diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index e98cdb0ab87..1d54074dc43 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -92,7 +92,7 @@ from ._quant_ops import * from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode, ConfusionMatrix, PopulationCount, UpdateState, Load, CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull, PullWeight, PushWeight, - PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc) + PushMetrics, StartFLJob, UpdateModel, GetModel, PyFunc, ExchangeKeys, GetKeys) from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, CusMatMulCubeDenseRight, CusMatMulCubeFraczLeftCast, Im2Col, NewIm2Col, diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index fa9fbf5c2b7..2376dc561ea 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -885,6 +885,40 @@ class GetModel(PrimitiveWithInfer): return mstype.float32 +class ExchangeKeys(PrimitiveWithInfer): + """ + Exchange pairwise public keys for federated learning worker. + """ + @prim_attr_register + def __init__(self): + self.add_prim_attr("primitive_target", "CPU") + self.add_prim_attr('side_effect_mem', True) + self.init_prim_io_names(inputs=[], outputs=["result"]) + + def infer_shape(self): + return [1] + + def infer_dtype(self): + return mstype.float32 + + +class GetKeys(PrimitiveWithInfer): + """ + Get pairwise public keys for federated learning worker. + """ + @prim_attr_register + def __init__(self): + self.add_prim_attr("primitive_target", "CPU") + self.add_prim_attr('side_effect_mem', True) + self.init_prim_io_names(inputs=[], outputs=["result"]) + + def infer_shape(self): + return [1] + + def infer_dtype(self): + return mstype.float32 + + class identity(Primitive): """ Makes a identify primitive, used for pynative mode. diff --git a/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_worker.py b/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_worker.py index 641c7f19c22..3340fe31478 100644 --- a/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_worker.py +++ b/tests/st/fl/cross_silo_femnist/run_cross_silo_femnist_worker.py @@ -31,6 +31,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) parser.add_argument("--local_worker_num", type=int, default=-1) parser.add_argument("--config_file_path", type=str, default="") parser.add_argument("--dataset_path", type=str, default="") +parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") args, _ = parser.parse_known_args() device_target = args.device_target @@ -47,6 +48,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration local_worker_num = args.local_worker_num config_file_path = args.config_file_path dataset_path = args.dataset_path +encrypt_type = args.encrypt_type if local_worker_num == -1: local_worker_num = worker_num @@ -73,6 +75,7 @@ for i in range(local_worker_num): cmd_worker += " --client_learning_rate=" + str(client_learning_rate) cmd_worker += " --worker_step_num_per_iteration=" + str(worker_step_num_per_iteration) cmd_worker += " --dataset_path=" + str(dataset_path) + cmd_worker += " --encrypt_type=" + str(encrypt_type) cmd_worker += " --user_id=" + str(i) cmd_worker += " > worker.log 2>&1 &" diff --git a/tests/st/fl/cross_silo_femnist/test_cross_silo_femnist.py b/tests/st/fl/cross_silo_femnist/test_cross_silo_femnist.py index c7981d3873e..b905c2e45f2 100644 --- a/tests/st/fl/cross_silo_femnist/test_cross_silo_femnist.py +++ b/tests/st/fl/cross_silo_femnist/test_cross_silo_femnist.py @@ -58,6 +58,7 @@ parser.add_argument("--worker_step_num_per_iteration", type=int, default=65) parser.add_argument("--scheduler_manage_port", type=int, default=11202) parser.add_argument("--config_file_path", type=str, default="") parser.add_argument("--encrypt_type", type=str, default="NOT_ENCRYPT") +parser.add_argument("--cipher_time_window", type=int, default=300000) parser.add_argument("--dataset_path", type=str, default="") # The user_id is used to set each worker's dataset path. parser.add_argument("--user_id", type=str, default="0") @@ -87,6 +88,7 @@ worker_step_num_per_iteration = args.worker_step_num_per_iteration scheduler_manage_port = args.scheduler_manage_port config_file_path = args.config_file_path encrypt_type = args.encrypt_type +cipher_time_window = args.cipher_time_window dataset_path = args.dataset_path user_id = args.user_id @@ -111,7 +113,8 @@ ctx = { "worker_step_num_per_iteration": worker_step_num_per_iteration, "scheduler_manage_port": scheduler_manage_port, "config_file_path": config_file_path, - "encrypt_type": encrypt_type + "encrypt_type": encrypt_type, + "cipher_time_window": cipher_time_window, } context.set_context(mode=context.GRAPH_MODE, device_target=device_target) @@ -285,6 +288,24 @@ class StartFLJob(nn.Cell): return self.start_fl_job() +class ExchangeKeys(nn.Cell): + def __init__(self): + super(ExchangeKeys, self).__init__() + self.exchange_keys = P.ExchangeKeys() + + def construct(self): + return self.exchange_keys() + + +class GetKeys(nn.Cell): + def __init__(self): + super(GetKeys, self).__init__() + self.get_keys = P.GetKeys() + + def construct(self): + return self.get_keys() + + class UpdateAndGetModel(nn.Cell): def __init__(self, weights): super(UpdateAndGetModel, self).__init__() @@ -328,6 +349,11 @@ def train(): if context.get_fl_context("ms_role") == "MS_WORKER": start_fl_job = StartFLJob(dataset.get_dataset_size() * args.client_batch_size) start_fl_job() + if encrypt_type == "STABLE_PW_ENCRYPT": + exchange_keys = ExchangeKeys() + exchange_keys() + get_keys = GetKeys() + get_keys() for _ in range(epoch): print("step is ", epoch, flush=True)