diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 98abe8f7e89..4dfe300f726 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -227,6 +227,11 @@ set(SUB_COMP common debug pybind_api utils vm profiler ps ) +if(ENABLE_CPU AND NOT WIN32) + add_compile_definitions(ENABLE_ARMOUR) + list(APPEND SUB_COMP "armour") +endif() + foreach(_comp ${SUB_COMP}) add_subdirectory(${_comp}) string(REPLACE "/" "_" sub ${_comp}) diff --git a/mindspore/ccsrc/armour/CMakeLists.txt b/mindspore/ccsrc/armour/CMakeLists.txt index 3d9ea951d63..c4e0cfdd05d 100644 --- a/mindspore/ccsrc/armour/CMakeLists.txt +++ b/mindspore/ccsrc/armour/CMakeLists.txt @@ -1,14 +1,5 @@ file(GLOB_RECURSE ARMOUR_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") -if(NOT ENABLE_CPU OR WIN32) - list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_init.cc") - list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_keys.cc") - list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_meta_storage.cc") - list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_reconstruct.cc") - list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_shares.cc") - list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_unmask.cc") -endif() - set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema") set(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs diff --git a/mindspore/ccsrc/armour/cipher/cipher_init.h b/mindspore/ccsrc/armour/cipher/cipher_init.h index c4babe2e8d4..aa5ba057745 100644 --- a/mindspore/ccsrc/armour/cipher/cipher_init.h +++ b/mindspore/ccsrc/armour/cipher/cipher_init.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CIPHER_INIT_H -#define MINDSPORE_CIPHER_INIT_H +#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_INIT_H +#define MINDSPORE_CCSRC_ARMOUR_CIPHER_INIT_H #include #include @@ -74,4 +74,4 @@ class CipherInit { } // namespace armour } // namespace mindspore -#endif // MINDSPORE_CIPHER_COMMON_H +#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_COMMON_H diff --git a/mindspore/ccsrc/armour/cipher/cipher_keys.h b/mindspore/ccsrc/armour/cipher/cipher_keys.h index 2e2bb7b30c8..58578eed9d6 100644 --- a/mindspore/ccsrc/armour/cipher/cipher_keys.h +++ b/mindspore/ccsrc/armour/cipher/cipher_keys.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CIPHER_KEYS_H -#define MINDSPORE_CIPHER_KEYS_H +#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H +#define MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H #include #include @@ -68,4 +68,4 @@ class CipherKeys { } // namespace armour } // namespace mindspore -#endif // MINDSPORE_CIPHER_KEYS_H +#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H diff --git a/mindspore/ccsrc/armour/cipher/cipher_meta_storage.h b/mindspore/ccsrc/armour/cipher/cipher_meta_storage.h index acf3b58c629..85c7f85ada6 100644 --- a/mindspore/ccsrc/armour/cipher/cipher_meta_storage.h +++ b/mindspore/ccsrc/armour/cipher/cipher_meta_storage.h @@ -14,16 +14,18 @@ * limitations under the License. */ -#ifndef MINDSPORE_CIPHER_META_STORAGE_H -#define MINDSPORE_CIPHER_META_STORAGE_H +#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_META_STORAGE_H +#define MINDSPORE_CCSRC_ARMOUR_CIPHER_META_STORAGE_H -#include #include #include #include #include #include #include +#ifndef _WIN32 +#include +#endif #include "proto/ps.pb.h" #include "utils/log_adapter.h" #include "armour/secure_protocol/secret_sharing.h" @@ -52,7 +54,7 @@ struct CipherPublicPara { float dp_eps; float dp_delta; float dp_norm_clip; - int encrypt_type; + string encrypt_type; }; class CipherMetaStorage { @@ -89,4 +91,4 @@ class CipherMetaStorage { } // namespace armour } // namespace mindspore -#endif // MINDSPORE_CIPHER_META_STORAGE_H +#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_META_STORAGE_H diff --git a/mindspore/ccsrc/armour/cipher/cipher_reconstruct.cc b/mindspore/ccsrc/armour/cipher/cipher_reconstruct.cc index 9771797c6c8..aacaa55c8a4 100644 --- a/mindspore/ccsrc/armour/cipher/cipher_reconstruct.cc +++ b/mindspore/ccsrc/armour/cipher/cipher_reconstruct.cc @@ -29,6 +29,10 @@ bool CipherReconStruct::CombineMask( const std::map> &reconstruct_secret_list, const std::vector &client_list) { bool retcode = true; +#ifdef _WIN32 + MS_LOG(ERROR) << "Unsupported feature in Windows platform."; + retcode = false; +#else for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) { // define flag_share: judge we need b or s bool flag_share = true; @@ -45,7 +49,6 @@ bool CipherReconStruct::CombineMask( mpz_init(prime); auto publicparam_ = CipherInit::GetInstance().GetPublicParams(); mpz_import(prime, PRIME_MAX_LEN, 1, 1, 0, 0, publicparam_->prime); - if (iter->second.size() >= cipher_init_->secrets_minnums_) { // combine private key seed. MS_LOG(INFO) << "start assign secrets shares to public shares "; for (int i = 0; i < static_cast(cipher_init_->secrets_minnums_); ++i) { @@ -89,7 +92,7 @@ bool CipherReconStruct::CombineMask( } } } - +#endif return retcode; } diff --git a/mindspore/ccsrc/armour/cipher/cipher_reconstruct.h b/mindspore/ccsrc/armour/cipher/cipher_reconstruct.h index 78146ce22bb..8806a17e83c 100644 --- a/mindspore/ccsrc/armour/cipher/cipher_reconstruct.h +++ b/mindspore/ccsrc/armour/cipher/cipher_reconstruct.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CIPHER_RECONSTRUCT_H -#define MINDSPORE_CIPHER_RECONSTRUCT_H +#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_RECONSTRUCT_H +#define MINDSPORE_CCSRC_ARMOUR_CIPHER_RECONSTRUCT_H #include #include @@ -84,4 +84,4 @@ class CipherReconStruct { } // namespace armour } // namespace mindspore -#endif // MINDSPORE_CIPHER_KEYS_H +#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H diff --git a/mindspore/ccsrc/armour/cipher/cipher_shares.h b/mindspore/ccsrc/armour/cipher/cipher_shares.h index add86c44950..79ac512d17b 100644 --- a/mindspore/ccsrc/armour/cipher/cipher_shares.h +++ b/mindspore/ccsrc/armour/cipher/cipher_shares.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CIPHER_SHARES_H -#define MINDSPORE_CIPHER_SHARES_H +#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_SHARES_H +#define MINDSPORE_CCSRC_ARMOUR_CIPHER_SHARES_H #include #include @@ -65,4 +65,4 @@ class CipherShares { } // namespace armour } // namespace mindspore -#endif // MINDSPORE_CIPHER_KEYS_H +#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_SHARES_H diff --git a/mindspore/ccsrc/armour/cipher/cipher_unmask.h b/mindspore/ccsrc/armour/cipher/cipher_unmask.h index 85ddbb421f9..3728edc74b1 100644 --- a/mindspore/ccsrc/armour/cipher/cipher_unmask.h +++ b/mindspore/ccsrc/armour/cipher/cipher_unmask.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CIPHER_UNMASK_H -#define MINDSPORE_CIPHER_UNMASK_H +#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_UNMASK_H +#define MINDSPORE_CCSRC_ARMOUR_CIPHER_UNMASK_H #include #include @@ -42,4 +42,4 @@ class CipherUnmask { } // namespace armour } // namespace mindspore -#endif // MINDSPORE_CIPHER_KEYS_H +#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_UNMASK_H diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 3cdf652cdd8..dcdccb038e9 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -672,14 +672,21 @@ bool StartServerAction(const ResourcePtr &res) { {"pullWeight"}, {"pushWeight", false, 3000, true, server_num, true}}; + float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio(); + float get_model_ratio = ps::PSContext::instance()->get_model_ratio(); + size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold(); + + ps::server::CipherConfig cipher_config = {share_secrets_ratio, get_model_ratio, reconstruct_secrets_threshhold}; + size_t executor_threshold = 0; if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { executor_threshold = update_model_threshold; - ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, func_graph, + ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph, executor_threshold); } else if (server_mode_ == ps::kServerModePS) { executor_threshold = worker_num; - ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, func_graph, executor_threshold); + ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph, + executor_threshold); } else { MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported."; return false; diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index df9c03811bc..eba2638b702 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -354,6 +354,11 @@ PYBIND11_MODULE(_c_expression, m) { "Set threshold count ratio for updateModel round.") .def("set_update_model_time_window", &PSContext::set_update_model_time_window, "Set time window for updateModel round.") + .def("set_share_secrets_ratio", &PSContext::set_share_secrets_ratio, + "Set threshold count ratio for share secrets round.") + .def("set_get_model_ratio", &PSContext::set_get_model_ratio, "Set threshold count ratio for get model round.") + .def("set_reconstruct_secrets_threshhold", &PSContext::set_reconstruct_secrets_threshhold, + "Set threshold count for reconstruct secrets round.") .def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.") .def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.") .def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.") diff --git a/mindspore/ccsrc/ps/core/protos/fl.proto b/mindspore/ccsrc/ps/core/protos/fl.proto index 5dd0c7ffd80..b338e067e08 100644 --- a/mindspore/ccsrc/ps/core/protos/fl.proto +++ b/mindspore/ccsrc/ps/core/protos/fl.proto @@ -105,6 +105,10 @@ message ClientNoises { OneClientNoises one_client_noises = 1; } +message PrimeList { + repeated bytes prime = 1; +} + message PairClientKeys { string fl_id = 1; KeysPb client_keys = 2; @@ -128,6 +132,10 @@ message KeysPb { repeated bytes key = 1; } +message Prime { + bytes prime = 1; +} + message PBMetadata { oneof value { DeviceMeta device_meta = 1; @@ -146,6 +154,9 @@ message PBMetadata { OneClientNoises one_client_noises = 10; ClientNoises client_noises = 11; + + Prime prime = 12; + PrimeList prime_list = 13; } } diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 87166202ad6..8478e901440 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -294,6 +294,20 @@ void PSContext::set_update_model_time_window(uint64_t update_model_time_window) uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; } +void PSContext::set_share_secrets_ratio(float share_secrets_ratio) { share_secrets_ratio_ = share_secrets_ratio; } + +float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; } + +void PSContext::set_get_model_ratio(float get_model_ratio) { get_model_ratio_ = get_model_ratio; } + +float PSContext::get_model_ratio() const { return get_model_ratio_; } + +void PSContext::set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold) { + reconstruct_secrets_threshhold_ = reconstruct_secrets_threshhold; +} + +uint64_t PSContext::reconstruct_secrets_threshhold() const { return reconstruct_secrets_threshhold_; } + void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; } const std::string &PSContext::fl_name() const { return fl_name_; } diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 8e91bf13bb0..cd9c5407524 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -35,6 +35,9 @@ constexpr char kEnvRoleOfServer[] = "MS_SERVER"; constexpr char kEnvRoleOfWorker[] = "MS_WORKER"; constexpr char kEnvRoleOfScheduler[] = "MS_SCHED"; constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS"; +constexpr char kDPEncryptType[] = "DPEncrypt"; +constexpr char kPWEncryptType[] = "PWEncrypt"; +constexpr char kNotEncryptType[] = "NotEncrypt"; // Use binary data to represent federated learning server's context so that we can judge which round resets the // iteration. From right to left, each bit stands for: @@ -125,6 +128,15 @@ class PSContext { void set_update_model_time_window(uint64_t update_model_time_window); uint64_t update_model_time_window() const; + void set_share_secrets_ratio(float share_secrets_ratio); + float share_secrets_ratio() const; + + void set_get_model_ratio(float get_model_ratio); + float get_model_ratio() const; + + void set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold); + uint64_t reconstruct_secrets_threshhold() const; + void set_fl_name(const std::string &fl_name); const std::string &fl_name() const; @@ -173,6 +185,9 @@ class PSContext { start_fl_job_time_window_(3000), update_model_ratio_(1.0), update_model_time_window_(3000), + share_secrets_ratio_(1.0), + get_model_ratio_(1.0), + reconstruct_secrets_threshhold_(2000), fl_iteration_num_(20), client_epoch_num_(25), client_batch_size_(32), @@ -222,6 +237,15 @@ class PSContext { // The time window of updateModel round in millisecond. uint64_t update_model_time_window_; + // Share model threshold is a certain ratio of share secrets threshold which is set as share_secrets_ratio_. + float share_secrets_ratio_; + + // Get model threshold is a certain ratio of get model threshold which is set as get_model_ratio_. + float get_model_ratio_; + + // The threshold count of reconstruct secrets round. Used in federated learning for now. + uint64_t reconstruct_secrets_threshhold_; + // Iteration number of federeated learning, which is the number of interactions between client and server. uint64_t fl_iteration_num_; diff --git a/mindspore/ccsrc/ps/server/common.h b/mindspore/ccsrc/ps/server/common.h index 08772175ea5..e95edabe554 100644 --- a/mindspore/ccsrc/ps/server/common.h +++ b/mindspore/ccsrc/ps/server/common.h @@ -59,6 +59,12 @@ struct RoundConfig { bool server_num_as_threshold = false; }; +struct CipherConfig { + float share_secrets_ratio = 1.0; + float get_model_ratio = 1.0; + size_t reconstruct_secrets_threshhold = 0; +}; + using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; using mindspore::kernel::CPUKernel; @@ -175,9 +181,17 @@ constexpr auto kCtxDeviceMetas = "device_metas"; constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration"; constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp"; constexpr auto kCtxUpdateModelClientList = "update_model_client_list"; -constexpr auto kCtxUpdateModelClientNum = "update_model_client_num"; constexpr auto kCtxUpdateModelThld = "update_model_threshold"; +constexpr auto kCtxUpdateModelClientNum = "update_model_client_num"; +constexpr auto kCtxClientsKeys = "clients_keys"; +constexpr auto kCtxClientNoises = "clients_noises"; +constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares"; +constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares"; +constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list"; +constexpr auto kCtxReconstructClientList = "reconstruct_client_list"; +constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list"; constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size"; +constexpr auto kCtxCipherPrimer = "cipher_primer"; // This macro the current timestamp in milliseconds. #define CURRENT_TIME_MILLI \ diff --git a/mindspore/ccsrc/ps/server/distributed_metadata_store.cc b/mindspore/ccsrc/ps/server/distributed_metadata_store.cc index ea64df883c8..43d14d341bb 100644 --- a/mindspore/ccsrc/ps/server/distributed_metadata_store.cc +++ b/mindspore/ccsrc/ps/server/distributed_metadata_store.cc @@ -238,6 +238,63 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P } else if (meta.has_update_model_threshold()) { auto update_model_threshold = metadata_[name].mutable_update_model_threshold(); *update_model_threshold = meta.update_model_threshold(); + } else if (meta.has_prime()) { + auto prime_list = metadata_[name].mutable_prime_list(); + auto &prime_id = meta.prime().prime(); + if (prime_list->prime_size() == 0) { + prime_list->add_prime(prime_id); + } + } else if (meta.has_pair_client_keys()) { + auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys(); + auto &fl_id = meta.pair_client_keys().fl_id(); + auto &client_keys = meta.pair_client_keys().client_keys(); + // Check whether the new item already exists. + bool add_flag = true; + for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); iter++) { + if (fl_id == iter->first) { + add_flag = false; + MS_LOG(ERROR) << "Leader server updating value for " << name + << " failed: The Protobuffer of this value already exists."; + break; + } + } + if (add_flag) { + client_keys_map[fl_id] = client_keys; + } else { + return false; + } + } else if (meta.has_pair_client_shares()) { + auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares(); + auto &fl_id = meta.pair_client_shares().fl_id(); + auto &client_shares = meta.pair_client_shares().client_shares(); + // google::protobuf::Map< std::string, mindspore::ps::core::SharesPb >::const_iterator iter; + // Check whether the new item already exists. + bool add_flag = true; + for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); iter++) { + if (fl_id == iter->first) { + add_flag = false; + MS_LOG(ERROR) << "Leader server updating value for " << name + << " failed: The Protobuffer of this value already exists."; + break; + } + } + if (add_flag) { + client_shares_map[fl_id] = client_shares; + } else { + return false; + } + } else if (meta.has_one_client_noises()) { + auto &client_noises = *metadata_[name].mutable_client_noises(); + if (client_noises.has_one_client_noises()) { + MS_LOG(ERROR) << "Leader server updating value for " << name + << " failed: The Protobuffer of this value already exists."; + client_noises.Clear(); + } + client_noises.mutable_one_client_noises()->MergeFrom(meta.one_client_noises()); + } else { + MS_LOG(ERROR) << "Leader server updating value for " << name + << " failed: The Protobuffer of this value is not defined."; + return false; } return true; } diff --git a/mindspore/ccsrc/ps/server/executor.cc b/mindspore/ccsrc/ps/server/executor.cc index af392075e70..ce2c5602afb 100644 --- a/mindspore/ccsrc/ps/server/executor.cc +++ b/mindspore/ccsrc/ps/server/executor.cc @@ -268,10 +268,14 @@ std::map Executor::GetModel() { return model; } -// bool Executor::Unmask() { -// auto model = GetModel(); -// return mindarmour::CipherMgr::GetInstance().UnMask(model); -// } +bool Executor::Unmask() { +#ifdef ENABLE_ARMOUR + auto model = GetModel(); + return cipher_unmask_.UnMask(model); +#else + return false; +#endif +} const std::vector &Executor::param_names() const { return param_names_; } diff --git a/mindspore/ccsrc/ps/server/executor.h b/mindspore/ccsrc/ps/server/executor.h index 0a43546a732..893d34d8b70 100644 --- a/mindspore/ccsrc/ps/server/executor.h +++ b/mindspore/ccsrc/ps/server/executor.h @@ -26,6 +26,9 @@ #include #include "ps/server/common.h" #include "ps/server/parameter_aggregator.h" +#ifdef ENABLE_ARMOUR +#include "armour/cipher/cipher_unmask.h" +#endif namespace mindspore { namespace ps { @@ -88,6 +91,7 @@ class Executor { bool initialized() const; const std::vector ¶m_names() const; + bool Unmask(); private: Executor() {} @@ -117,6 +121,9 @@ class Executor { // Because ParameterAggregator is not threadsafe, we have to create mutex for each ParameterAggregator so we can // acquire lock before calling its method. std::map parameter_mutex_; +#ifdef ENABLE_ARMOUR + armour::CipherUnmask cipher_unmask_; +#endif }; } // namespace server } // namespace ps diff --git a/mindspore/ccsrc/ps/server/kernel/round/client_list_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/client_list_kernel.cc new file mode 100644 index 00000000000..6f0d003cd81 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/client_list_kernel.cc @@ -0,0 +1,199 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/round/client_list_kernel.h" +#include +#include +#include +#include +#include "schema/cipher_generated.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +void ClientListKernel::InitKernel(size_t) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + + executor_ = &Executor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor_); + if (!executor_->initialized()) { + MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; + return; + } + cipher_init_ = &armour::CipherInit::GetInstance(); +} + +bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req, + std::shared_ptr fbb) { + bool response = false; + std::vector client_list; + std::string fl_id = get_clients_req->fl_id()->str(); + int32_t iter_client = (size_t)get_clients_req->iteration(); + if (iter_num != (size_t)iter_client) { + MS_LOG(ERROR) << "ClientListKernel iteration invalid. servertime is " << iter_num; + MS_LOG(ERROR) << "ClientListKernel iteration invalid. clienttime is " << iter_client; + BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) { + uint64_t update_model_client_num = LocalMetaStore::GetInstance().value(kCtxUpdateModelThld); + PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList); + const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list(); + for (int i = 0; i < client_list_pb.fl_id_size(); ++i) { + client_list.push_back(client_list_pb.fl_id(i)); + } + if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) { // client in client_list. + if (static_cast(client_list_pb.fl_id_size()) >= update_model_client_num) { + MS_LOG(INFO) << "send clients_list succeed!"; + MS_LOG(INFO) << "UpdateModel client list: "; + for (size_t i = 0; i < client_list.size(); ++i) { + MS_LOG(INFO) << " fl_id : " << client_list[i]; + } + MS_LOG(INFO) << "update_model_client_num: " << update_model_client_num; + BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + response = true; + } else { + MS_LOG(INFO) << "The server is not ready. update_model_client_need_num: " << update_model_client_num; + MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size(); + /*for (size_t i = 0; i < std::min(client_list.size(), size_t(2)); ++i) { + MS_LOG(INFO) << " client_list fl_id : " << client_list[i]; + } + for (size_t i = client_list.size() - size_t(1); i > std::max(client_list.size() - size_t(2), size_t(0)); + --i) { + MS_LOG(INFO) << " client_list fl_id : " << client_list[i]; + }*/ + int count_tmp = 0; + for (size_t i = 0; i < cipher_init_->get_model_num_need_; ++i) { + size_t j = 0; + for (; j < client_list.size(); ++j) { + if (("f" + std::to_string(i)) == client_list[j]) break; + } + if (j >= client_list.size()) { + count_tmp++; + MS_LOG(INFO) << " no client_list fl_id : " << i; + if (count_tmp > 3) break; + } + } + BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } + } + if (response) { + DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str()); + } + } else { + MS_LOG(ERROR) << "update_model_client_num is zero."; + BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_num is zero.", client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } + } + return response; +} +bool ClientListKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + std::shared_ptr fbb = std::make_shared(); + bool response = false; + size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); + size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration; + clock_t start_time = clock(); + + std::vector client_list; + if (inputs.size() != 1) { + MS_LOG(ERROR) << "ClientListKernel needs 1 input,but got " << inputs.size(); + BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel input num not match", client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else if (outputs.size() != 1) { + MS_LOG(ERROR) << "ClientListKernel needs 1 output,but got " << outputs.size(); + BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel output num not match", client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + MS_LOG(ERROR) << "Current amount for GetClientList is enough."; + BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "ClientListKernel num is enough", client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + void *req_data = inputs[0]->addr; + const schema::GetClientList *get_clients_req = flatbuffers::GetRoot(req_data); + + if (get_clients_req == nullptr || fbb == nullptr) { + MS_LOG(ERROR) << "GetClientList is nullptr or ClientListRsp builder is nullptr."; + BuildClientListRsp(fbb, schema::ResponseCode_RequestError, + "GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + response = DealClient(iter_num, get_clients_req, fbb); + } + } + } + + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + clock_t end_time = clock(); + double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); + MS_LOG(INFO) << "client_list_kernel success time is : " << duration; + return response; +} // namespace ps + +bool ClientListKernel::Reset() { + MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + MS_LOG(INFO) << "Get Client list kernel reset!"; + DistributedCountService::GetInstance().ResetCounter(name_); + StopTimer(); + return true; +} + +void ClientListKernel::BuildClientListRsp(std::shared_ptr client_list_resp_builder, + const schema::ResponseCode retcode, const string &reason, + std::vector clients, const string &next_req_time, + const int iteration) { + auto rsp_reason = client_list_resp_builder->CreateString(reason); + auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time); + if (clients.size() > 0) { + std::vector> clients_vector; + for (auto client : clients) { + auto client_fb = client_list_resp_builder->CreateString(client); + clients_vector.push_back(client_fb); + } + auto clients_fb = client_list_resp_builder->CreateVector(clients_vector); + schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get())); + rsp_builder.add_retcode(retcode); + rsp_builder.add_reason(rsp_reason); + rsp_builder.add_clients(clients_fb); + rsp_builder.add_iteration(iteration); + rsp_builder.add_next_req_time(rsp_next_req_time); + auto rsp_exchange_keys = rsp_builder.Finish(); + client_list_resp_builder->Finish(rsp_exchange_keys); + } else { + schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get())); + rsp_builder.add_retcode(retcode); + rsp_builder.add_reason(rsp_reason); + rsp_builder.add_iteration(iteration); + rsp_builder.add_next_req_time(rsp_next_req_time); + auto rsp_exchange_keys = rsp_builder.Finish(); + client_list_resp_builder->Finish(rsp_exchange_keys); + } + return; +} + +REG_ROUND_KERNEL(getClientList, ClientListKernel) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/client_list_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/client_list_kernel.h new file mode 100644 index 00000000000..36992fe7598 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/client_list_kernel.h @@ -0,0 +1,55 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H +#include +#include +#include +#include "ps/server/common.h" +#include "ps/server/kernel/round/round_kernel.h" +#include "ps/server/kernel/round/round_kernel_factory.h" +#include "armour/cipher/cipher_init.h" +#include "ps/server/executor.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +class ClientListKernel : public RoundKernel { + public: + ClientListKernel() = default; + ~ClientListKernel() override = default; + void InitKernel(size_t required_cnt) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + bool Reset() override; + void BuildClientListRsp(std::shared_ptr client_list_resp_builder, + const schema::ResponseCode retcode, const string &reason, std::vector clients, + const string &next_req_time, const int iteration); + + private: + armour::CipherInit *cipher_init_; + bool DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req, + std::shared_ptr fbb); + Executor *executor_; + size_t iteration_time_window_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H diff --git a/mindspore/ccsrc/ps/server/kernel/round/exchange_keys_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/exchange_keys_kernel.cc new file mode 100644 index 00000000000..98f1251cc17 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/exchange_keys_kernel.cc @@ -0,0 +1,101 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/round/exchange_keys_kernel.h" +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +void ExchangeKeysKernel::InitKernel(size_t) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + + executor_ = &Executor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor_); + if (!executor_->initialized()) { + MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; + return; + } + + cipher_key_ = &armour::CipherKeys::GetInstance(); +} + +bool ExchangeKeysKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + std::shared_ptr fbb = std::make_shared(); + bool response = false; + size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); + size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is " + << total_duration; + clock_t start_time = clock(); + + if (inputs.size() != 1) { + MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 input,but got " << inputs.size(); + cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel input num not match", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else if (outputs.size() != 1) { + MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 output,but got " << outputs.size(); + cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel output num not match", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + MS_LOG(ERROR) << "Current amount for ExchangeKeysKernel is enough."; + cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, + "Current amount for ExchangeKeysKernel is enough.", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + void *req_data = inputs[0]->addr; + const schema::RequestExchangeKeys *exchange_keys_req = + flatbuffers::GetRoot(req_data); + int32_t iter_client = (size_t)exchange_keys_req->iteration(); + if (iter_num != (size_t)iter_client) { + MS_LOG(ERROR) << "ExchangeKeysKernel iteration invalid. server now iteration is " << iter_num + << ". client request iteration is " << iter_client; + cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + response = + cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb); + if (response) { + DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str()); + } + } + } + } + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + clock_t end_time = clock(); + double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); + MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration; + return response; +} + +bool ExchangeKeysKernel::Reset() { + MS_LOG(INFO) << "exchange keys kernel reset, ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + DistributedCountService::GetInstance().ResetCounter(name_); + StopTimer(); + return true; +} +REG_ROUND_KERNEL(exchangeKeys, ExchangeKeysKernel) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/exchange_keys_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/exchange_keys_kernel.h new file mode 100644 index 00000000000..977b140eebc --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/exchange_keys_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H + +#include +#include "ps/server/common.h" +#include "ps/server/kernel/round/round_kernel.h" +#include "ps/server/kernel/round/round_kernel_factory.h" +#include "ps/server/executor.h" +#include "armour/cipher/cipher_keys.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +class ExchangeKeysKernel : public RoundKernel { + public: + ExchangeKeysKernel() = default; + ~ExchangeKeysKernel() override = default; + void InitKernel(size_t required_cnt) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + bool Reset() override; + + private: + Executor *executor_; + size_t iteration_time_window_; + armour::CipherKeys *cipher_key_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H diff --git a/mindspore/ccsrc/ps/server/kernel/round/get_keys_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/get_keys_kernel.cc new file mode 100644 index 00000000000..c32ab7d9e27 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/get_keys_kernel.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/round/get_keys_kernel.h" +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +void GetKeysKernel::InitKernel(size_t) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + + executor_ = &Executor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor_); + if (!executor_->initialized()) { + MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; + return; + } + + cipher_key_ = &armour::CipherKeys::GetInstance(); +} + +bool GetKeysKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + std::shared_ptr fbb = std::make_shared(); + bool response = false; + size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); + size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetKeysKernel allowed Duration Is " + << total_duration; + clock_t start_time = clock(); + + if (inputs.size() != 1) { + MS_LOG(ERROR) << "GetKeysKernel needs 1 input,but got " << inputs.size(); + cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num, + std::to_string(CURRENT_TIME_MILLI.count()), false); + } else if (outputs.size() != 1) { + MS_LOG(ERROR) << "GetKeysKernel needs 1 output,but got " << outputs.size(); + cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num, + std::to_string(CURRENT_TIME_MILLI.count()), false); + } else { + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough."; + cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num, + std::to_string(CURRENT_TIME_MILLI.count()), false); + } else { + void *req_data = inputs[0]->addr; + const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot(req_data); + int32_t iter_client = (size_t)get_exchange_keys_req->iteration(); + if (iter_num != (size_t)iter_client) { + MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num + << ". client request iteration is " << iter_client; + cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num, + std::to_string(CURRENT_TIME_MILLI.count()), false); + } else { + response = + cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb); + if (response) { + DistributedCountService::GetInstance().Count(name_, get_exchange_keys_req->fl_id()->str()); + } + } + } + } + GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize()); + clock_t end_time = clock(); + double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); + MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration; + return response; +} + +bool GetKeysKernel::Reset() { + MS_LOG(INFO) << "get keys kernel reset! ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + cipher_key_->ClearKeys(); + DistributedCountService::GetInstance().ResetCounter(name_); + StopTimer(); + return true; +} + +REG_ROUND_KERNEL(getKeys, GetKeysKernel) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/get_keys_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/get_keys_kernel.h new file mode 100644 index 00000000000..3902bb664b1 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/get_keys_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H + +#include +#include "ps/server/common.h" +#include "ps/server/kernel/round/round_kernel.h" +#include "ps/server/kernel/round/round_kernel_factory.h" +#include "ps/server/executor.h" +#include "armour/cipher/cipher_keys.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +class GetKeysKernel : public RoundKernel { + public: + GetKeysKernel() = default; + ~GetKeysKernel() override = default; + void InitKernel(size_t required_cnt) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + bool Reset() override; + + private: + Executor *executor_; + size_t iteration_time_window_; + armour::CipherKeys *cipher_key_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H diff --git a/mindspore/ccsrc/ps/server/kernel/round/get_secrets_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/get_secrets_kernel.cc new file mode 100644 index 00000000000..bb1dc2c084a --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/get_secrets_kernel.cc @@ -0,0 +1,103 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/round/get_secrets_kernel.h" +#include +#include +#include +#include "armour/cipher/cipher_shares.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +void GetSecretsKernel::InitKernel(size_t) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + + executor_ = &Executor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor_); + if (!executor_->initialized()) { + MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; + return; + } + + cipher_share_ = &armour::CipherShares::GetInstance(); +} + +bool GetSecretsKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + bool response = false; + std::shared_ptr fbb = std::make_shared(); + + size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); + MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count()); + size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is " + << total_duration; + + clock_t start_time = clock(); + + if (inputs.size() != 1) { + MS_LOG(ERROR) << "GetSecretsKernel needs 1 input,but got " << inputs.size(); + cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0); + } else if (outputs.size() != 1) { + MS_LOG(ERROR) << "GetSecretsKernel needs 1 output,but got " << outputs.size(); + cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0); + } else { + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough."; + cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0); + } else { + void *req_data = inputs[0]->addr; + const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot(req_data); + int32_t iter_client = (size_t)get_secrets_req->iteration(); + if (iter_num != (size_t)iter_client) { + MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num + << ". client request iteration is " << iter_client; + cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0); + } else { + response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp); + if (response) { + DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str()); + } + } + } + } + + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + clock_t end_time = clock(); + double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); + MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration; + return response; +} + +bool GetSecretsKernel::Reset() { + MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + MS_LOG(INFO) << "GetSecretsKernel reset!"; + cipher_share_->ClearShareSecrets(); + DistributedCountService::GetInstance().ResetCounter(name_); + StopTimer(); + return true; +} + +REG_ROUND_KERNEL(getSecrets, GetSecretsKernel) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/get_secrets_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/get_secrets_kernel.h new file mode 100644 index 00000000000..8e9d5c5e997 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/get_secrets_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H + +#include +#include "ps/server/common.h" +#include "ps/server/kernel/round/round_kernel.h" +#include "ps/server/kernel/round/round_kernel_factory.h" +#include "armour/cipher/cipher_shares.h" +#include "ps/server/executor.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +class GetSecretsKernel : public RoundKernel { + public: + GetSecretsKernel() = default; + ~GetSecretsKernel() override = default; + void InitKernel(size_t required_cnt) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + bool Reset() override; + + private: + Executor *executor_; + size_t iteration_time_window_; + armour::CipherShares *cipher_share_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H diff --git a/mindspore/ccsrc/ps/server/kernel/round/reconstruct_secrets_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/reconstruct_secrets_kernel.cc new file mode 100644 index 00000000000..c81bf9d771e --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/reconstruct_secrets_kernel.cc @@ -0,0 +1,165 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/round/reconstruct_secrets_kernel.h" +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +void ReconstructSecretsKernel::InitKernel(size_t required_cnt) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + + executor_ = &Executor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor_); + if (!executor_->initialized()) { + MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; + return; + } + auto last_cnt_handler = [&](std::shared_ptr) { + MS_LOG(INFO) << "start FinishIteration"; + FinishIteration(); + MS_LOG(INFO) << "end FinishIteration"; + return; + }; + auto first_cnt_handler = [&](std::shared_ptr) { return; }; + name_unmask_ = "UnMaskKernel"; + MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : " + << LocalMetaStore::GetInstance().curr_iter_num(); + DistributedCountService::GetInstance().RegisterCounter(name_unmask_, PSContext::instance()->initial_server_num(), + {first_cnt_handler, last_cnt_handler}); +} + +bool ReconstructSecretsKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + std::shared_ptr fbb = std::make_shared(); + bool response = false; + size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); + // MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + MS_LOG(INFO) << "Iteration number is " << iter_num << ", ReconstructSecretsKernel total duration is " + << total_duration; + clock_t start_time = clock(); + + if (inputs.size() != 1) { + MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size(); + cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError, + "ReconstructSecretsKernel input num not match.", iter_num, + std::to_string(CURRENT_TIME_MILLI.count())); + } else if (outputs.size() != 1) { + MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 output, but got " << outputs.size(); + cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError, + "ReconstructSecretsKernel output num not match.", iter_num, + std::to_string(CURRENT_TIME_MILLI.count())); + } else { + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough."; + cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, + "Current amount for ReconstructSecretsKernel is enough.", iter_num, + std::to_string(CURRENT_TIME_MILLI.count())); + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return false; + } + + void *req_data = inputs[0]->addr; + const schema::SendReconstructSecret *reconstruct_secret_req = + flatbuffers::GetRoot(req_data); + + // get client list from memory server. + std::vector client_list; + uint64_t update_model_client_num = 0; + if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) { + update_model_client_num = LocalMetaStore::GetInstance().value(kCtxUpdateModelThld); + } else { + MS_LOG(ERROR) << "update_model_client_num is zero."; + cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError, + "update_model_client_num is zero.", iter_num, + std::to_string(CURRENT_TIME_MILLI.count())); + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return false; + } + const PBMetadata client_list_pb_out = + DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList); + + const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list(); + int client_list_actual_size = client_list_pb.fl_id_size(); + if (client_list_actual_size < 0) { + client_list_actual_size = 0; + } + if (static_cast(client_list_actual_size) < update_model_client_num) { + MS_LOG(INFO) << "ReconstructSecretsKernel : client list is not ready " << inputs.size(); + cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SucNotReady, + "ReconstructSecretsKernel : client list is not ready", iter_num, + std::to_string(CURRENT_TIME_MILLI.count())); + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return response; + } + for (int i = 0; i < client_list_pb.fl_id_size(); ++i) { + client_list.push_back(client_list_pb.fl_id(i)); + } + response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), + reconstruct_secret_req, fbb, client_list); + if (response) { + // MS_LOG(INFO) << "start ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str(); + DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str()); + // MS_LOG(INFO) << "end ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str(); + } + } + + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + clock_t end_time = clock(); + double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); + MS_LOG(INFO) << "reconstruct_secrets_kernel success time is : " << duration; + return response; +} + +void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr &message) { + MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + if (true) { // todo: PSContext::instance()->encrypt_type == PWEncrypt { + while (!Executor::GetInstance().IsAllWeightAggregationDone()) { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + + MS_LOG(INFO) << "start unmask"; + while (!Executor::GetInstance().Unmask()) { + std::this_thread::sleep_for(std::chrono::milliseconds(5)); + } + MS_LOG(INFO) << "end unmask"; + std::string worker_id = std::to_string(DistributedCountService::GetInstance().local_rank()); + DistributedCountService::GetInstance().Count(name_unmask_, worker_id); + } +} + +bool ReconstructSecretsKernel::Reset() { + MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + MS_LOG(INFO) << "reconstruct secrets kernel reset!"; + DistributedCountService::GetInstance().ResetCounter(name_); + DistributedCountService::GetInstance().ResetCounter(name_unmask_); + StopTimer(); + cipher_reconstruct_.ClearReconstructSecrets(); + return true; +} + +REG_ROUND_KERNEL(reconstructSecrets, ReconstructSecretsKernel) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/reconstruct_secrets_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/reconstruct_secrets_kernel.h new file mode 100644 index 00000000000..934f6316460 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/reconstruct_secrets_kernel.h @@ -0,0 +1,54 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_ +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_ + +#include +#include +#include +#include "ps/server/common.h" +#include "ps/server/kernel/round/round_kernel.h" +#include "ps/server/kernel/round/round_kernel_factory.h" +#include "armour/cipher/cipher_reconstruct.h" +#include "ps/server/executor.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +class ReconstructSecretsKernel : public RoundKernel { + public: + ReconstructSecretsKernel() = default; + ~ReconstructSecretsKernel() override = default; + + void InitKernel(size_t required_cnt) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + bool Reset() override; + void OnLastCountEvent(const std::shared_ptr &message) override; + + private: + std::string name_unmask_; + Executor *executor_; + size_t iteration_time_window_; + armour::CipherReconStruct cipher_reconstruct_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_ diff --git a/mindspore/ccsrc/ps/server/kernel/round/share_secrets_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/share_secrets_kernel.cc new file mode 100644 index 00000000000..262c7d83265 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/share_secrets_kernel.cc @@ -0,0 +1,102 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/server/kernel/round/share_secrets_kernel.h" +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +void ShareSecretsKernel::InitKernel(size_t) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + + executor_ = &Executor::GetInstance(); + MS_EXCEPTION_IF_NULL(executor_); + if (!executor_->initialized()) { + MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; + return; + } + cipher_share_ = &armour::CipherShares::GetInstance(); +} + +bool ShareSecretsKernel::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) { + bool response = false; + std::shared_ptr fbb = std::make_shared(); + size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); + size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is " + << total_duration; + clock_t start_time = clock(); + + if (inputs.size() != 1) { + MS_LOG(ERROR) << "ShareSecretsKernel needs 1 input,but got " << inputs.size(); + cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError, "ShareSecretsKernel input num not match", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else if (outputs.size() != 1) { + MS_LOG(ERROR) << "ShareSecretsKernel needs 1 output,but got " << outputs.size(); + cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError, + "ShareSecretsKernel output num not match", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough."; + cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, + "Current amount for ShareSecretsKernel is enough.", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + void *req_data = inputs[0]->addr; + const schema::RequestShareSecrets *share_secrets_req = + flatbuffers::GetRoot(req_data); + size_t iter_client = (size_t)share_secrets_req->iteration(); + if (iter_num != iter_client) { + MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num + << ". client request iteration is " << iter_client; + cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + response = + cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count())); + if (response) { + DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str()); + } + } + } + } + + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + clock_t end_time = clock(); + double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); + MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration; + return response; +} + +bool ShareSecretsKernel::Reset() { + MS_LOG(INFO) << "share_secrets_kernel reset! ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + DistributedCountService::GetInstance().ResetCounter(name_); + StopTimer(); + return true; +} + +REG_ROUND_KERNEL(shareSecrets, ShareSecretsKernel) +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/server/kernel/round/share_secrets_kernel.h b/mindspore/ccsrc/ps/server/kernel/round/share_secrets_kernel.h new file mode 100644 index 00000000000..fd06c03f342 --- /dev/null +++ b/mindspore/ccsrc/ps/server/kernel/round/share_secrets_kernel.h @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H +#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H + +#include +#include "ps/server/common.h" +#include "ps/server/executor.h" +#include "ps/server/kernel/round/round_kernel.h" +#include "ps/server/kernel/round/round_kernel_factory.h" +#include "armour/cipher/cipher_shares.h" + +namespace mindspore { +namespace ps { +namespace server { +namespace kernel { +class ShareSecretsKernel : public RoundKernel { + public: + ShareSecretsKernel() = default; + ~ShareSecretsKernel() override = default; + void InitKernel(size_t required_cnt) override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + bool Reset() override; + + private: + Executor *executor_; + size_t iteration_time_window_; + armour::CipherShares *cipher_share_; +}; +} // namespace kernel +} // namespace server +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H diff --git a/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc index a12f719f452..33088d169a5 100644 --- a/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/ps/server/kernel/round/start_fl_job_kernel.cc @@ -21,6 +21,9 @@ #include #include "ps/server/model_store.h" #include "ps/server/iteration.h" +#ifdef ENABLE_ARMOUR +#include "armour/cipher/cipher_init.h" +#endif namespace mindspore { namespace ps { @@ -192,6 +195,21 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode()); auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name()); +#ifdef ENABLE_ARMOUR + auto *param = armour::CipherInit::GetInstance().GetPublicParams(); + auto prime = fbb->CreateVector(param->prime, PRIME_MAX_LEN); + auto p = fbb->CreateVector(param->p, SECRET_MAX_LEN); + int32_t t = param->t; + int32_t g = param->g; + float dp_eps = param->dp_eps; + float dp_delta = param->dp_delta; + float dp_norm_clip = param->dp_norm_clip; + auto encrypt_type = fbb->CreateString(param->encrypt_type); + + auto cipher_public_params = + schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type); +#endif + schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); fl_plan_builder.add_fl_name(fbs_fl_name); fl_plan_builder.add_server_mode(fbs_server_mode); @@ -199,6 +217,9 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num()); fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size()); fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate()); +#ifdef ENABLE_ARMOUR + fl_plan_builder.add_cipher(cipher_public_params); +#endif auto fbs_fl_plan = fl_plan_builder.Finish(); std::vector> fbs_feature_maps; diff --git a/mindspore/ccsrc/ps/server/server.cc b/mindspore/ccsrc/ps/server/server.cc index f42000af0f8..4c2de2300b2 100644 --- a/mindspore/ccsrc/ps/server/server.cc +++ b/mindspore/ccsrc/ps/server/server.cc @@ -18,6 +18,9 @@ #include #include #include +#ifdef ENABLE_ARMOUR +#include "armour/secure_protocol/secret_sharing.h" +#endif #include "ps/server/round.h" #include "ps/server/model_store.h" #include "ps/server/iteration.h" @@ -39,7 +42,7 @@ void SignalHandler(int signal) { } void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector &rounds_config, - const FuncGraphPtr &func_graph, size_t executor_threshold) { + const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) { MS_EXCEPTION_IF_NULL(func_graph); func_graph_ = func_graph; @@ -48,6 +51,7 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s return; } rounds_config_ = rounds_config; + cipher_config_ = cipher_config; use_tcp_ = use_tcp; use_http_ = use_http; @@ -80,6 +84,7 @@ void Server::Run() { RegisterCommCallbacks(); StartCommunicator(); InitExecutor(); + InitCipher(); RegisterRoundKernel(); MS_LOG(INFO) << "Server started successfully."; safemode_ = false; @@ -177,6 +182,42 @@ void Server::InitIteration() { iteration_->AddRound(round); } + cipher_initial_client_cnt_ = rounds_config_[0].threshold_count; + cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0; + cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio; + cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count; + cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count; + cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold; + + MS_LOG(INFO) << "Initializing cipher:"; + MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_ + << " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_ + << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_; + MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_ + << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_ + << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_; + +#ifdef ENABLE_ARMOUR + std::shared_ptr exchange_keys_round = + std::make_shared("exchangeKeys", false, 3000, true, cipher_exchange_secrets_cnt_); + iteration_->AddRound(exchange_keys_round); + std::shared_ptr get_keys_round = + std::make_shared("getKeys", false, 3000, true, cipher_exchange_secrets_cnt_); + iteration_->AddRound(get_keys_round); + std::shared_ptr share_secrets_round = + std::make_shared("shareSecrets", false, 3000, true, cipher_share_secrets_cnt_); + iteration_->AddRound(share_secrets_round); + std::shared_ptr get_secrets_round = + std::make_shared("getSecrets", false, 3000, true, cipher_share_secrets_cnt_); + iteration_->AddRound(get_secrets_round); + std::shared_ptr get_clientlist_round = + std::make_shared("getClientList", false, 3000, true, cipher_get_clientlist_cnt_); + iteration_->AddRound(get_clientlist_round); + std::shared_ptr reconstruct_secrets_round = + std::make_shared("reconstructSecrets", false, 3000, true, cipher_reconstruct_secrets_up_cnt_); + iteration_->AddRound(reconstruct_secrets_round); +#endif + // 2.Initialize all the rounds. TimeOutCb time_out_cb = std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2); @@ -186,6 +227,37 @@ void Server::InitIteration() { return; } +void Server::InitCipher() { +#ifdef ENABLE_ARMOUR + cipher_init_ = &armour::CipherInit::GetInstance(); + + int cipher_t = cipher_reconstruct_secrets_down_cnt_; + unsigned char cipher_p[SECRET_MAX_LEN] = {0}; + int cipher_g = 1; + unsigned char cipher_prime[PRIME_MAX_LEN] = {0}; + + mpz_t prim; + mpz_init(prim); + mindspore::armour::GetRandomPrime(prim); + mindspore::armour::PrintBigInteger(prim, 16); + + size_t len_cipher_prime; + mpz_export((unsigned char *)cipher_prime, &len_cipher_prime, sizeof(unsigned char), 1, 0, 0, prim); + mindspore::armour::CipherPublicPara param; + param.g = cipher_g; + param.t = cipher_t; + memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN); + memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN); + // param.dp_delta = dp_delta; + // param.dp_eps = dp_eps; + // param.dp_norm_clip = dp_norm_clip; + param.encrypt_type = kNotEncryptType; // PSContext::instance()->encrypt_type; + cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_, + cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_, + cipher_reconstruct_secrets_up_cnt_); +#endif +} + void Server::RegisterCommCallbacks() { // The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register // rounds' callbacks. diff --git a/mindspore/ccsrc/ps/server/server.h b/mindspore/ccsrc/ps/server/server.h index b028a73dac8..06595a83bce 100644 --- a/mindspore/ccsrc/ps/server/server.h +++ b/mindspore/ccsrc/ps/server/server.h @@ -26,6 +26,9 @@ #include "ps/server/common.h" #include "ps/server/executor.h" #include "ps/server/iteration.h" +#ifdef ENABLE_ARMOUR +#include "armour/cipher/cipher_init.h" +#endif namespace mindspore { namespace ps { @@ -39,7 +42,7 @@ class Server { } void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector &rounds_config, - const FuncGraphPtr &func_graph, size_t executor_threshold); + const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold); // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be // blocked until the server is finalized. @@ -92,6 +95,9 @@ class Server { // Initialize executor according to the server mode. void InitExecutor(); + // Initialize cipher according to the public param. + void InitCipher(); + // Create round kernels and bind these kernels with corresponding Round. void RegisterRoundKernel(); @@ -120,6 +126,7 @@ class Server { // The configure of all rounds. std::vector rounds_config_; + CipherConfig cipher_config_; // The graph passed by the frontend without backend optimizing. FuncGraphPtr func_graph_; @@ -147,10 +154,25 @@ class Server { std::atomic_bool safemode_; // Variables set by ps context. +#ifdef ENABLE_ARMOUR + armour::CipherInit *cipher_init_; +#endif std::string scheduler_ip_; uint16_t scheduler_port_; uint32_t server_num_; uint32_t worker_num_; + uint16_t fl_server_port_; + size_t start_fl_job_cnt_; + size_t update_model_cnt_; + size_t cipher_initial_client_cnt_; + size_t cipher_exchange_secrets_cnt_; + size_t cipher_share_secrets_cnt_; + size_t cipher_get_clientlist_cnt_; + size_t cipher_reconstruct_secrets_up_cnt_; + size_t cipher_reconstruct_secrets_down_cnt_; + + float percent_for_update_model_; + float percent_for_get_model_; }; } // namespace server } // namespace ps diff --git a/mindspore/context.py b/mindspore/context.py index 01ee18f42f4..2484c7578b9 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -836,14 +836,15 @@ def set_fl_context(**kwargs): fl_server_port (int): The http port of the federated learning server. Normally for each server this should be set to the same value. Default: 6668. enable_fl_client (bool): Whether this process is federated learning client. Default: False. - start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0. + start_fl_job_threshold (int): The threshold count of startFLJob. Default: 1. start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000. - update_model_ratio (float): The ratio for computing the threshold count of updateModel - which will be multiplied by start_fl_job_threshold. - Must be between 0 and 1.0.Default: 1.0. + share_secrets_ratio (float): The ratio for computing the threshold count of share secrets. Default: 1.0. + update_model_ratio (float): The ratio for computing the threshold count of updateModel. Default: 1.0. + get_model_ratio (float): The ratio for computing the threshold count of get model. Default: 1.0. + reconstruct_secrets_threshold (int): The threshold count of reconstruct threshold. Default: 0. update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000. - fl_name (str): The federated learning job name. Default: ''. - fl_iteration_num (int): Iteration number of federeated learning, + fl_name (string): The federated learning job name. Default: ''. + fl_iteration_num (int): Iteration number of federated learning, which is the number of interactions between client and server. Default: 20. client_epoch_num (int): Client training epoch number. Default: 25. client_batch_size (int): Client training data batch size. Default: 32. diff --git a/mindspore/core/gvar/log_adapter_common.cc b/mindspore/core/gvar/log_adapter_common.cc index ac4b84e2067..4526ecb25ff 100644 --- a/mindspore/core/gvar/log_adapter_common.cc +++ b/mindspore/core/gvar/log_adapter_common.cc @@ -45,6 +45,7 @@ static const std::vector sub_module_names = { "PROFILER", // SM_PROFILER "PS", // SM_PS "LITE", // SM_LITE + "ARMOUR", // SM_ARMOUR "HCCL_ADPT", // SM_HCCL_ADPT "MINDQUANTUM", // SM_MINDQUANTUM "RUNTIME_FRAMEWORK", // SM_RUNTIME_FRAMEWORK diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h index 8d48211145c..d2bf4f875ce 100644 --- a/mindspore/core/utils/log_adapter.h +++ b/mindspore/core/utils/log_adapter.h @@ -132,6 +132,7 @@ enum SubModuleId : int { SM_PROFILER, // profiler SM_PS, // Parameter Server SM_LITE, // LITE + SM_ARMOUR, // ARMOUR SM_HCCL_ADPT, // Hccl Adapter SM_MINDQUANTUM, // MindQuantum SM_RUNTIME_FRAMEWORK, // Runtime framework diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py index a590f057091..dda7fbd84c4 100644 --- a/mindspore/parallel/_ps_context.py +++ b/mindspore/parallel/_ps_context.py @@ -57,6 +57,9 @@ _set_ps_context_func_map = { "start_fl_job_time_window": ps_context().set_start_fl_job_time_window, "update_model_ratio": ps_context().set_update_model_ratio, "update_model_time_window": ps_context().set_update_model_time_window, + "share_secrets_ratio": ps_context().set_share_secrets_ratio, + "get_model_ratio": ps_context().set_get_model_ratio, + "reconstruct_secrets_threshhold": ps_context().set_reconstruct_secrets_threshhold, "fl_name": ps_context().set_fl_name, "fl_iteration_num": ps_context().set_fl_iteration_num, "client_epoch_num": ps_context().set_client_epoch_num, diff --git a/mindspore/schema/cipher.fbs b/mindspore/schema/cipher.fbs index bd27cfb7aa2..6b5059c8ce8 100644 --- a/mindspore/schema/cipher.fbs +++ b/mindspore/schema/cipher.fbs @@ -24,7 +24,7 @@ table CipherPublicParams { dp_eps:float; dp_delta:float; dp_norm_clip:float; - encrypt_type:int; + encrypt_type:string; } table ClientPublicKeys { diff --git a/tests/st/fl/mobile/run_mobile_server.py b/tests/st/fl/mobile/run_mobile_server.py index 93d6517ec90..8e062fa6bf0 100644 --- a/tests/st/fl/mobile/run_mobile_server.py +++ b/tests/st/fl/mobile/run_mobile_server.py @@ -28,6 +28,9 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1) parser.add_argument("--start_fl_job_time_window", type=int, default=3000) parser.add_argument("--update_model_ratio", type=float, default=1.0) parser.add_argument("--update_model_time_window", type=int, default=3000) +parser.add_argument("--share_secrets_ratio", type=float, default=1.0) +parser.add_argument("--get_model_ratio", type=float, default=1.0) +parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0) parser.add_argument("--fl_name", type=str, default="Lenet") parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) @@ -48,6 +51,9 @@ if __name__ == "__main__": start_fl_job_time_window = args.start_fl_job_time_window update_model_ratio = args.update_model_ratio update_model_time_window = args.update_model_time_window + share_secrets_ratio = args.share_secrets_ratio + get_model_ratio = args.get_model_ratio + reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold fl_name = args.fl_name fl_iteration_num = args.fl_iteration_num client_epoch_num = args.client_epoch_num @@ -78,6 +84,9 @@ if __name__ == "__main__": cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window) cmd_server += " --update_model_ratio=" + str(update_model_ratio) cmd_server += " --update_model_time_window=" + str(update_model_time_window) + cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio) + cmd_server += " --get_model_ratio=" + str(get_model_ratio) + cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold) cmd_server += " --fl_name=" + fl_name cmd_server += " --fl_iteration_num=" + str(fl_iteration_num) cmd_server += " --client_epoch_num=" + str(client_epoch_num) diff --git a/tests/st/fl/mobile/test_mobile_lenet.py b/tests/st/fl/mobile/test_mobile_lenet.py index c208b245027..2c8c31abefb 100644 --- a/tests/st/fl/mobile/test_mobile_lenet.py +++ b/tests/st/fl/mobile/test_mobile_lenet.py @@ -36,6 +36,9 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1) parser.add_argument("--start_fl_job_time_window", type=int, default=3000) parser.add_argument("--update_model_ratio", type=float, default=1.0) parser.add_argument("--update_model_time_window", type=int, default=3000) +parser.add_argument("--share_secrets_ratio", type=float, default=1.0) +parser.add_argument("--get_model_ratio", type=float, default=1.0) +parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0) parser.add_argument("--fl_name", type=str, default="Lenet") parser.add_argument("--fl_iteration_num", type=int, default=25) parser.add_argument("--client_epoch_num", type=int, default=20) @@ -56,6 +59,9 @@ start_fl_job_threshold = args.start_fl_job_threshold start_fl_job_time_window = args.start_fl_job_time_window update_model_ratio = args.update_model_ratio update_model_time_window = args.update_model_time_window +share_secrets_ratio = args.share_secrets_ratio +get_model_ratio = args.get_model_ratio +reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold fl_name = args.fl_name fl_iteration_num = args.fl_iteration_num client_epoch_num = args.client_epoch_num @@ -76,6 +82,9 @@ ctx = { "start_fl_job_time_window": start_fl_job_time_window, "update_model_ratio": update_model_ratio, "update_model_time_window": update_model_time_window, + "share_secrets_ratio": share_secrets_ratio, + "get_model_ratio": get_model_ratio, + "reconstruct_secrets_threshhold": reconstruct_secrets_threshhold, "fl_name": fl_name, "fl_iteration_num": fl_iteration_num, "client_epoch_num": client_epoch_num, diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index cc48d88e814..f90d2930f27 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -159,6 +159,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/ps/*.cc" "../../../mindspore/ccsrc/profiler/device/common/*.cc" "../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c" + "../../../mindspore/ccsrc/armour/*.cc" ) list(REMOVE_ITEM MINDSPORE_SRC_LIST