cipher calling code
This commit is contained in:
parent
3bf35c3ae7
commit
5af0c890a1
|
@ -227,6 +227,11 @@ set(SUB_COMP
|
|||
common debug pybind_api utils vm profiler ps
|
||||
)
|
||||
|
||||
if(ENABLE_CPU AND NOT WIN32)
|
||||
add_compile_definitions(ENABLE_ARMOUR)
|
||||
list(APPEND SUB_COMP "armour")
|
||||
endif()
|
||||
|
||||
foreach(_comp ${SUB_COMP})
|
||||
add_subdirectory(${_comp})
|
||||
string(REPLACE "/" "_" sub ${_comp})
|
||||
|
|
|
@ -1,14 +1,5 @@
|
|||
file(GLOB_RECURSE ARMOUR_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
|
||||
if(NOT ENABLE_CPU OR WIN32)
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_init.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_keys.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_meta_storage.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_reconstruct.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_shares.cc")
|
||||
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_unmask.cc")
|
||||
endif()
|
||||
|
||||
set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema")
|
||||
set(FBS_FILES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CIPHER_INIT_H
|
||||
#define MINDSPORE_CIPHER_INIT_H
|
||||
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_INIT_H
|
||||
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_INIT_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
@ -74,4 +74,4 @@ class CipherInit {
|
|||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CIPHER_COMMON_H
|
||||
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_COMMON_H
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CIPHER_KEYS_H
|
||||
#define MINDSPORE_CIPHER_KEYS_H
|
||||
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H
|
||||
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
@ -68,4 +68,4 @@ class CipherKeys {
|
|||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CIPHER_KEYS_H
|
||||
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H
|
||||
|
|
|
@ -14,16 +14,18 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CIPHER_META_STORAGE_H
|
||||
#define MINDSPORE_CIPHER_META_STORAGE_H
|
||||
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_META_STORAGE_H
|
||||
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_META_STORAGE_H
|
||||
|
||||
#include <gmp.h>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#ifndef _WIN32
|
||||
#include <gmp.h>
|
||||
#endif
|
||||
#include "proto/ps.pb.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "armour/secure_protocol/secret_sharing.h"
|
||||
|
@ -52,7 +54,7 @@ struct CipherPublicPara {
|
|||
float dp_eps;
|
||||
float dp_delta;
|
||||
float dp_norm_clip;
|
||||
int encrypt_type;
|
||||
string encrypt_type;
|
||||
};
|
||||
|
||||
class CipherMetaStorage {
|
||||
|
@ -89,4 +91,4 @@ class CipherMetaStorage {
|
|||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CIPHER_META_STORAGE_H
|
||||
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_META_STORAGE_H
|
||||
|
|
|
@ -29,6 +29,10 @@ bool CipherReconStruct::CombineMask(
|
|||
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
|
||||
const std::vector<string> &client_list) {
|
||||
bool retcode = true;
|
||||
#ifdef _WIN32
|
||||
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
|
||||
retcode = false;
|
||||
#else
|
||||
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
|
||||
// define flag_share: judge we need b or s
|
||||
bool flag_share = true;
|
||||
|
@ -45,7 +49,6 @@ bool CipherReconStruct::CombineMask(
|
|||
mpz_init(prime);
|
||||
auto publicparam_ = CipherInit::GetInstance().GetPublicParams();
|
||||
mpz_import(prime, PRIME_MAX_LEN, 1, 1, 0, 0, publicparam_->prime);
|
||||
|
||||
if (iter->second.size() >= cipher_init_->secrets_minnums_) { // combine private key seed.
|
||||
MS_LOG(INFO) << "start assign secrets shares to public shares ";
|
||||
for (int i = 0; i < static_cast<int>(cipher_init_->secrets_minnums_); ++i) {
|
||||
|
@ -89,7 +92,7 @@ bool CipherReconStruct::CombineMask(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
return retcode;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CIPHER_RECONSTRUCT_H
|
||||
#define MINDSPORE_CIPHER_RECONSTRUCT_H
|
||||
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_RECONSTRUCT_H
|
||||
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_RECONSTRUCT_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
@ -84,4 +84,4 @@ class CipherReconStruct {
|
|||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CIPHER_KEYS_H
|
||||
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CIPHER_SHARES_H
|
||||
#define MINDSPORE_CIPHER_SHARES_H
|
||||
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_SHARES_H
|
||||
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_SHARES_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
@ -65,4 +65,4 @@ class CipherShares {
|
|||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CIPHER_KEYS_H
|
||||
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_SHARES_H
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CIPHER_UNMASK_H
|
||||
#define MINDSPORE_CIPHER_UNMASK_H
|
||||
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_UNMASK_H
|
||||
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_UNMASK_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
@ -42,4 +42,4 @@ class CipherUnmask {
|
|||
} // namespace armour
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CIPHER_KEYS_H
|
||||
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_UNMASK_H
|
||||
|
|
|
@ -672,14 +672,21 @@ bool StartServerAction(const ResourcePtr &res) {
|
|||
{"pullWeight"},
|
||||
{"pushWeight", false, 3000, true, server_num, true}};
|
||||
|
||||
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
|
||||
float get_model_ratio = ps::PSContext::instance()->get_model_ratio();
|
||||
size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
|
||||
|
||||
ps::server::CipherConfig cipher_config = {share_secrets_ratio, get_model_ratio, reconstruct_secrets_threshhold};
|
||||
|
||||
size_t executor_threshold = 0;
|
||||
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
||||
executor_threshold = update_model_threshold;
|
||||
ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, func_graph,
|
||||
ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
|
||||
executor_threshold);
|
||||
} else if (server_mode_ == ps::kServerModePS) {
|
||||
executor_threshold = worker_num;
|
||||
ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, func_graph, executor_threshold);
|
||||
ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
|
||||
executor_threshold);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
|
||||
return false;
|
||||
|
|
|
@ -354,6 +354,11 @@ PYBIND11_MODULE(_c_expression, m) {
|
|||
"Set threshold count ratio for updateModel round.")
|
||||
.def("set_update_model_time_window", &PSContext::set_update_model_time_window,
|
||||
"Set time window for updateModel round.")
|
||||
.def("set_share_secrets_ratio", &PSContext::set_share_secrets_ratio,
|
||||
"Set threshold count ratio for share secrets round.")
|
||||
.def("set_get_model_ratio", &PSContext::set_get_model_ratio, "Set threshold count ratio for get model round.")
|
||||
.def("set_reconstruct_secrets_threshhold", &PSContext::set_reconstruct_secrets_threshhold,
|
||||
"Set threshold count for reconstruct secrets round.")
|
||||
.def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
|
||||
.def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
|
||||
.def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
|
||||
|
|
|
@ -105,6 +105,10 @@ message ClientNoises {
|
|||
OneClientNoises one_client_noises = 1;
|
||||
}
|
||||
|
||||
message PrimeList {
|
||||
repeated bytes prime = 1;
|
||||
}
|
||||
|
||||
message PairClientKeys {
|
||||
string fl_id = 1;
|
||||
KeysPb client_keys = 2;
|
||||
|
@ -128,6 +132,10 @@ message KeysPb {
|
|||
repeated bytes key = 1;
|
||||
}
|
||||
|
||||
message Prime {
|
||||
bytes prime = 1;
|
||||
}
|
||||
|
||||
message PBMetadata {
|
||||
oneof value {
|
||||
DeviceMeta device_meta = 1;
|
||||
|
@ -146,6 +154,9 @@ message PBMetadata {
|
|||
|
||||
OneClientNoises one_client_noises = 10;
|
||||
ClientNoises client_noises = 11;
|
||||
|
||||
Prime prime = 12;
|
||||
PrimeList prime_list = 13;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -294,6 +294,20 @@ void PSContext::set_update_model_time_window(uint64_t update_model_time_window)
|
|||
|
||||
uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
|
||||
|
||||
void PSContext::set_share_secrets_ratio(float share_secrets_ratio) { share_secrets_ratio_ = share_secrets_ratio; }
|
||||
|
||||
float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
|
||||
|
||||
void PSContext::set_get_model_ratio(float get_model_ratio) { get_model_ratio_ = get_model_ratio; }
|
||||
|
||||
float PSContext::get_model_ratio() const { return get_model_ratio_; }
|
||||
|
||||
void PSContext::set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold) {
|
||||
reconstruct_secrets_threshhold_ = reconstruct_secrets_threshhold;
|
||||
}
|
||||
|
||||
uint64_t PSContext::reconstruct_secrets_threshhold() const { return reconstruct_secrets_threshhold_; }
|
||||
|
||||
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
|
||||
|
||||
const std::string &PSContext::fl_name() const { return fl_name_; }
|
||||
|
|
|
@ -35,6 +35,9 @@ constexpr char kEnvRoleOfServer[] = "MS_SERVER";
|
|||
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
|
||||
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
|
||||
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
|
||||
constexpr char kDPEncryptType[] = "DPEncrypt";
|
||||
constexpr char kPWEncryptType[] = "PWEncrypt";
|
||||
constexpr char kNotEncryptType[] = "NotEncrypt";
|
||||
|
||||
// Use binary data to represent federated learning server's context so that we can judge which round resets the
|
||||
// iteration. From right to left, each bit stands for:
|
||||
|
@ -125,6 +128,15 @@ class PSContext {
|
|||
void set_update_model_time_window(uint64_t update_model_time_window);
|
||||
uint64_t update_model_time_window() const;
|
||||
|
||||
void set_share_secrets_ratio(float share_secrets_ratio);
|
||||
float share_secrets_ratio() const;
|
||||
|
||||
void set_get_model_ratio(float get_model_ratio);
|
||||
float get_model_ratio() const;
|
||||
|
||||
void set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold);
|
||||
uint64_t reconstruct_secrets_threshhold() const;
|
||||
|
||||
void set_fl_name(const std::string &fl_name);
|
||||
const std::string &fl_name() const;
|
||||
|
||||
|
@ -173,6 +185,9 @@ class PSContext {
|
|||
start_fl_job_time_window_(3000),
|
||||
update_model_ratio_(1.0),
|
||||
update_model_time_window_(3000),
|
||||
share_secrets_ratio_(1.0),
|
||||
get_model_ratio_(1.0),
|
||||
reconstruct_secrets_threshhold_(2000),
|
||||
fl_iteration_num_(20),
|
||||
client_epoch_num_(25),
|
||||
client_batch_size_(32),
|
||||
|
@ -222,6 +237,15 @@ class PSContext {
|
|||
// The time window of updateModel round in millisecond.
|
||||
uint64_t update_model_time_window_;
|
||||
|
||||
// Threshold count ratio for the share secrets round, stored as share_secrets_ratio_.
|
||||
float share_secrets_ratio_;
|
||||
|
||||
// Threshold count ratio for the get model round, stored as get_model_ratio_.
|
||||
float get_model_ratio_;
|
||||
|
||||
// The threshold count of reconstruct secrets round. Used in federated learning for now.
|
||||
uint64_t reconstruct_secrets_threshhold_;
|
||||
|
||||
// Iteration number of federated learning, which is the number of interactions between client and server.
|
||||
uint64_t fl_iteration_num_;
|
||||
|
||||
|
|
|
@ -59,6 +59,12 @@ struct RoundConfig {
|
|||
bool server_num_as_threshold = false;
|
||||
};
|
||||
|
||||
// Settings handed to the server for the cipher-related rounds.
struct CipherConfig {
  // Threshold count ratio of the share secrets round.
  float share_secrets_ratio{1.0f};
  // Threshold count ratio of the get model round.
  float get_model_ratio{1.0f};
  // Threshold count of the reconstruct secrets round.
  size_t reconstruct_secrets_threshhold{0};
};
|
||||
|
||||
using mindspore::kernel::Address;
|
||||
using mindspore::kernel::AddressPtr;
|
||||
using mindspore::kernel::CPUKernel;
|
||||
|
@ -175,9 +181,17 @@ constexpr auto kCtxDeviceMetas = "device_metas";
|
|||
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
|
||||
constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
|
||||
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
|
||||
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
|
||||
constexpr auto kCtxUpdateModelThld = "update_model_threshold";
|
||||
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
|
||||
constexpr auto kCtxClientsKeys = "clients_keys";
|
||||
constexpr auto kCtxClientNoises = "clients_noises";
|
||||
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";
|
||||
constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares";
|
||||
constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list";
|
||||
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
|
||||
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
|
||||
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
|
||||
constexpr auto kCtxCipherPrimer = "cipher_primer";
|
||||
|
||||
// This macro evaluates to the current timestamp in milliseconds.
|
||||
#define CURRENT_TIME_MILLI \
|
||||
|
|
|
@ -238,6 +238,63 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
|
|||
} else if (meta.has_update_model_threshold()) {
|
||||
auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
|
||||
*update_model_threshold = meta.update_model_threshold();
|
||||
} else if (meta.has_prime()) {
|
||||
auto prime_list = metadata_[name].mutable_prime_list();
|
||||
auto &prime_id = meta.prime().prime();
|
||||
if (prime_list->prime_size() == 0) {
|
||||
prime_list->add_prime(prime_id);
|
||||
}
|
||||
} else if (meta.has_pair_client_keys()) {
|
||||
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
|
||||
auto &fl_id = meta.pair_client_keys().fl_id();
|
||||
auto &client_keys = meta.pair_client_keys().client_keys();
|
||||
// Check whether the new item already exists.
|
||||
bool add_flag = true;
|
||||
for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); iter++) {
|
||||
if (fl_id == iter->first) {
|
||||
add_flag = false;
|
||||
MS_LOG(ERROR) << "Leader server updating value for " << name
|
||||
<< " failed: The Protobuffer of this value already exists.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (add_flag) {
|
||||
client_keys_map[fl_id] = client_keys;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else if (meta.has_pair_client_shares()) {
|
||||
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
|
||||
auto &fl_id = meta.pair_client_shares().fl_id();
|
||||
auto &client_shares = meta.pair_client_shares().client_shares();
|
||||
// google::protobuf::Map< std::string, mindspore::ps::core::SharesPb >::const_iterator iter;
|
||||
// Check whether the new item already exists.
|
||||
bool add_flag = true;
|
||||
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); iter++) {
|
||||
if (fl_id == iter->first) {
|
||||
add_flag = false;
|
||||
MS_LOG(ERROR) << "Leader server updating value for " << name
|
||||
<< " failed: The Protobuffer of this value already exists.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (add_flag) {
|
||||
client_shares_map[fl_id] = client_shares;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else if (meta.has_one_client_noises()) {
|
||||
auto &client_noises = *metadata_[name].mutable_client_noises();
|
||||
if (client_noises.has_one_client_noises()) {
|
||||
MS_LOG(ERROR) << "Leader server updating value for " << name
|
||||
<< " failed: The Protobuffer of this value already exists.";
|
||||
client_noises.Clear();
|
||||
}
|
||||
client_noises.mutable_one_client_noises()->MergeFrom(meta.one_client_noises());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Leader server updating value for " << name
|
||||
<< " failed: The Protobuffer of this value is not defined.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -268,10 +268,14 @@ std::map<std::string, AddressPtr> Executor::GetModel() {
|
|||
return model;
|
||||
}
|
||||
|
||||
// bool Executor::Unmask() {
|
||||
// auto model = GetModel();
|
||||
// return mindarmour::CipherMgr::GetInstance().UnMask(model);
|
||||
// }
|
||||
// Runs the cipher unmask step on the model returned by GetModel().
// Builds without ENABLE_ARMOUR have no cipher support and always report failure.
bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
  auto model_to_unmask = GetModel();
  const bool unmasked = cipher_unmask_.UnMask(model_to_unmask);
  return unmasked;
#else
  return false;
#endif
}
|
||||
|
||||
const std::vector<std::string> &Executor::param_names() const { return param_names_; }
|
||||
|
||||
|
|
|
@ -26,6 +26,9 @@
|
|||
#include <condition_variable>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/parameter_aggregator.h"
|
||||
#ifdef ENABLE_ARMOUR
|
||||
#include "armour/cipher/cipher_unmask.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -88,6 +91,7 @@ class Executor {
|
|||
bool initialized() const;
|
||||
|
||||
const std::vector<std::string> ¶m_names() const;
|
||||
bool Unmask();
|
||||
|
||||
private:
|
||||
Executor() {}
|
||||
|
@ -117,6 +121,9 @@ class Executor {
|
|||
// Because ParameterAggregator is not threadsafe, we have to create mutex for each ParameterAggregator so we can
|
||||
// acquire lock before calling its method.
|
||||
std::map<std::string, std::mutex> parameter_mutex_;
|
||||
#ifdef ENABLE_ARMOUR
|
||||
armour::CipherUnmask cipher_unmask_;
|
||||
#endif
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
|
|
|
@ -0,0 +1,199 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/client_list_kernel.h"
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "schema/cipher_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
void ClientListKernel::InitKernel(size_t) {
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
cipher_init_ = &armour::CipherInit::GetInstance();
|
||||
}
|
||||
|
||||
// Answers a getClientList request whose flatbuffers root has already been obtained.
//
// iter_num: the server's current iteration number.
// get_clients_req: the parsed request; must not be null.
// fbb: builder that receives the ReturnClientList response. Every path through this function builds exactly one
//      response, so the caller can always read the finished buffer afterwards.
//
// Returns true only when the complete update-model client list was sent to the requester (the request is then
// counted for this round). Returns false for every rejected or not-yet-ready request.
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
                                  std::shared_ptr<server::FBBuilder> fbb) {
  std::vector<std::string> client_list;

  // fl_id() hands back a pointer; reject requests that do not carry the field instead of dereferencing null.
  if (get_clients_req->fl_id() == nullptr) {
    MS_LOG(ERROR) << "ClientListKernel request has no fl_id.";
    BuildClientListRsp(fbb, schema::ResponseCode_RequestError, "fl_id is missing in the request.", client_list,
                       std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
    return false;
  }
  const std::string fl_id = get_clients_req->fl_id()->str();

  // The client must be working on the same iteration as the server.
  const int32_t iter_client = static_cast<int32_t>(get_clients_req->iteration());
  if (iter_num != static_cast<size_t>(iter_client)) {
    MS_LOG(ERROR) << "ClientListKernel iteration invalid. servertime is " << iter_num;
    MS_LOG(ERROR) << "ClientListKernel iteration invalid. clienttime is " << iter_client;
    BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
                       std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
    return false;
  }

  if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
    MS_LOG(ERROR) << "update_model_client_num is zero.";
    BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_num is zero.", client_list,
                       std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
    return false;
  }
  const uint64_t update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);

  // Collect the clients recorded for the update model round and check whether the requester is one of them.
  PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
  const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
  bool requester_in_list = false;
  for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
    const std::string &listed_id = client_list_pb.fl_id(i);
    if (listed_id == fl_id) {
      requester_in_list = true;
    }
    client_list.push_back(listed_id);
  }

  if (!requester_in_list) {
    // Previously no response was built on this path, leaving the builder unfinished for the caller.
    // The client list is deliberately not disclosed to a requester that is not part of it.
    MS_LOG(ERROR) << "ClientListKernel: fl_id " << fl_id << " is not in the update model client list.";
    BuildClientListRsp(fbb, schema::ResponseCode_RequestError, "fl_id is not in the update model client list.",
                       std::vector<std::string>(), std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
    return false;
  }

  if (static_cast<uint64_t>(client_list_pb.fl_id_size()) < update_model_client_num) {
    MS_LOG(INFO) << "The server is not ready. update_model_client_need_num: " << update_model_client_num;
    MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
    // Diagnostic only: report a few ids of the form "f<index>" that have not shown up yet.
    // NOTE(review): this presumes clients are named "f0", "f1", ... - confirm against the client naming scheme.
    int count_tmp = 0;
    for (size_t i = 0; i < cipher_init_->get_model_num_need_; ++i) {
      const std::string expected_id = "f" + std::to_string(i);
      bool present = false;
      for (const auto &listed_id : client_list) {
        if (listed_id == expected_id) {
          present = true;
          break;
        }
      }
      if (!present) {
        count_tmp++;
        MS_LOG(INFO) << " no client_list fl_id : " << i;
        if (count_tmp > 3) break;
      }
    }
    BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", client_list,
                       std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
    return false;
  }

  MS_LOG(INFO) << "send clients_list succeed!";
  MS_LOG(INFO) << "UpdateModel client list: ";
  for (const auto &listed_id : client_list) {
    MS_LOG(INFO) << " fl_id : " << listed_id;
  }
  MS_LOG(INFO) << "update_model_client_num: " << update_model_client_num;
  BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
                     std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
  // Only successfully served requests count towards this round's threshold.
  DistributedCountService::GetInstance().Count(name_, fl_id);
  return true;
}
|
||||
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
std::vector<string> client_list;
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ClientListKernel needs 1 input,but got " << inputs.size();
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel input num not match", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ClientListKernel needs 1 output,but got " << outputs.size();
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel output num not match", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "ClientListKernel num is enough", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
|
||||
|
||||
if (get_clients_req == nullptr || fbb == nullptr) {
|
||||
MS_LOG(ERROR) << "GetClientList is nullptr or ClientListRsp builder is nullptr.";
|
||||
BuildClientListRsp(fbb, schema::ResponseCode_RequestError,
|
||||
"GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
response = DealClient(iter_num, get_clients_req, fbb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
|
||||
return response;
|
||||
} // namespace ps
|
||||
|
||||
// Resets the per-round state of this kernel: clears the distributed counter registered under name_ and stops the
// round timer. Always returns true.
bool ClientListKernel::Reset() {
  const auto current_iteration = LocalMetaStore::GetInstance().curr_iter_num();
  MS_LOG(INFO) << "ITERATION NUMBER IS : " << current_iteration;
  MS_LOG(INFO) << "Get Client list kernel reset!";

  auto &&count_service = DistributedCountService::GetInstance();
  count_service.ResetCounter(name_);
  StopTimer();
  return true;
}
|
||||
|
||||
// Serializes a ReturnClientList response into the given flatbuffers builder and finishes the buffer.
//
// client_list_resp_builder: builder that receives the response; nothing is written when it is null.
// retcode: response code reported to the client.
// reason: human readable explanation of retcode.
// clients: client ids to return; the `clients` field is omitted from the response when this is empty.
// next_req_time: value of the next_req_time field (callers in this file pass a millisecond timestamp string).
// iteration: iteration number echoed back to the client.
void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> client_list_resp_builder,
                                          const schema::ResponseCode retcode, const string &reason,
                                          std::vector<std::string> clients, const string &next_req_time,
                                          const int iteration) {
  if (client_list_resp_builder == nullptr) {
    MS_LOG(ERROR) << "ClientListRsp builder is nullptr.";
    return;
  }

  // All strings and vectors have to be created before the table builder is started.
  auto rsp_reason = client_list_resp_builder->CreateString(reason);
  auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time);

  const bool has_clients = !clients.empty();
  flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> clients_fb;
  if (has_clients) {
    std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
    clients_vector.reserve(clients.size());
    for (const auto &client : clients) {
      clients_vector.push_back(client_list_resp_builder->CreateString(client));
    }
    clients_fb = client_list_resp_builder->CreateVector(clients_vector);
  }

  // Fields are added in the same order as before so the serialized layout does not change.
  schema::ReturnClientListBuilder rsp_builder(*client_list_resp_builder);
  rsp_builder.add_retcode(retcode);
  rsp_builder.add_reason(rsp_reason);
  if (has_clients) {
    rsp_builder.add_clients(clients_fb);
  }
  rsp_builder.add_iteration(iteration);
  rsp_builder.add_next_req_time(rsp_next_req_time);
  auto rsp_client_list = rsp_builder.Finish();
  client_list_resp_builder->Finish(rsp_client_list);
}
|
||||
|
||||
REG_ROUND_KERNEL(getClientList, ClientListKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
#include "armour/cipher/cipher_init.h"
|
||||
#include "ps/server/executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class ClientListKernel : public RoundKernel {
|
||||
public:
|
||||
ClientListKernel() = default;
|
||||
~ClientListKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Reset() override;
|
||||
void BuildClientListRsp(std::shared_ptr<server::FBBuilder> client_list_resp_builder,
|
||||
const schema::ResponseCode retcode, const string &reason, std::vector<std::string> clients,
|
||||
const string &next_req_time, const int iteration);
|
||||
|
||||
private:
|
||||
armour::CipherInit *cipher_init_;
|
||||
bool DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
|
||||
std::shared_ptr<server::FBBuilder> fbb);
|
||||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/exchange_keys_kernel.h"
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
void ExchangeKeysKernel::InitKernel(size_t) {
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
|
||||
cipher_key_ = &armour::CipherKeys::GetInstance();
|
||||
}
|
||||
|
||||
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is "
|
||||
<< total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 input,but got " << inputs.size();
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel input num not match",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 output,but got " << outputs.size();
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel output num not match",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for ExchangeKeysKernel is enough.";
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||
"Current amount for ExchangeKeysKernel is enough.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::RequestExchangeKeys *exchange_keys_req =
|
||||
flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
|
||||
int32_t iter_client = (size_t)exchange_keys_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "ExchangeKeysKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
response =
|
||||
cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
|
||||
return response;
|
||||
}
|
||||
|
||||
// Resets per-iteration state: clears this kernel's distributed counter and stops its timer.
// Always returns true.
bool ExchangeKeysKernel::Reset() {
  MS_LOG(INFO) << "exchange keys kernel reset, ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();

  auto &count_service = DistributedCountService::GetInstance();
  count_service.ResetCounter(name_);

  StopTimer();
  return true;
}
|
||||
REG_ROUND_KERNEL(exchangeKeys, ExchangeKeysKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
#include "ps/server/executor.h"
|
||||
#include "armour/cipher/cipher_keys.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class ExchangeKeysKernel : public RoundKernel {
|
||||
public:
|
||||
ExchangeKeysKernel() = default;
|
||||
~ExchangeKeysKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
armour::CipherKeys *cipher_key_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/get_keys_kernel.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
void GetKeysKernel::InitKernel(size_t) {
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
|
||||
cipher_key_ = &armour::CipherKeys::GetInstance();
|
||||
}
|
||||
|
||||
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetKeysKernel allowed Duration Is "
|
||||
<< total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "GetKeysKernel needs 1 input,but got " << inputs.size();
|
||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "GetKeysKernel needs 1 output,but got " << outputs.size();
|
||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
|
||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
|
||||
int32_t iter_client = (size_t)get_exchange_keys_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||
} else {
|
||||
response =
|
||||
cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, get_exchange_keys_req->fl_id()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
|
||||
return response;
|
||||
}
|
||||
|
||||
// Resets per-iteration state: clears the keys held by CipherKeys, resets this kernel's
// distributed counter and stops its timer. Always returns true.
bool GetKeysKernel::Reset() {
  MS_LOG(INFO) << "get keys kernel reset! ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();

  cipher_key_->ClearKeys();

  auto &count_service = DistributedCountService::GetInstance();
  count_service.ResetCounter(name_);

  StopTimer();
  return true;
}
|
||||
|
||||
REG_ROUND_KERNEL(getKeys, GetKeysKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
#include "ps/server/executor.h"
|
||||
#include "armour/cipher/cipher_keys.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class GetKeysKernel : public RoundKernel {
|
||||
public:
|
||||
GetKeysKernel() = default;
|
||||
~GetKeysKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
armour::CipherKeys *cipher_key_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
|
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/get_secrets_kernel.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "armour/cipher/cipher_shares.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
void GetSecretsKernel::InitKernel(size_t) {
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
|
||||
cipher_share_ = &armour::CipherShares::GetInstance();
|
||||
}
|
||||
|
||||
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool response = false;
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is "
|
||||
<< total_duration;
|
||||
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "GetSecretsKernel needs 1 input,but got " << inputs.size();
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "GetSecretsKernel needs 1 output,but got " << outputs.size();
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
|
||||
int32_t iter_client = (size_t)get_secrets_req->iteration();
|
||||
if (iter_num != (size_t)iter_client) {
|
||||
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
|
||||
} else {
|
||||
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
|
||||
return response;
|
||||
}
|
||||
|
||||
// Resets per-iteration state: clears the shares held by CipherShares, resets this kernel's
// distributed counter and stops its timer. Always returns true.
bool GetSecretsKernel::Reset() {
  MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
  MS_LOG(INFO) << "GetSecretsKernel reset!";

  cipher_share_->ClearShareSecrets();

  auto &count_service = DistributedCountService::GetInstance();
  count_service.ResetCounter(name_);

  StopTimer();
  return true;
}
|
||||
|
||||
REG_ROUND_KERNEL(getSecrets, GetSecretsKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
#include "armour/cipher/cipher_shares.h"
|
||||
#include "ps/server/executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class GetSecretsKernel : public RoundKernel {
|
||||
public:
|
||||
GetSecretsKernel() = default;
|
||||
~GetSecretsKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
armour::CipherShares *cipher_share_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
|
|
@ -0,0 +1,165 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/reconstruct_secrets_kernel.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
|
||||
MS_LOG(INFO) << "start FinishIteration";
|
||||
FinishIteration();
|
||||
MS_LOG(INFO) << "end FinishIteration";
|
||||
return;
|
||||
};
|
||||
auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) { return; };
|
||||
name_unmask_ = "UnMaskKernel";
|
||||
MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : "
|
||||
<< LocalMetaStore::GetInstance().curr_iter_num();
|
||||
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, PSContext::instance()->initial_server_num(),
|
||||
{first_cnt_handler, last_cnt_handler});
|
||||
}
|
||||
|
||||
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
bool response = false;
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
// MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ReconstructSecretsKernel total duration is "
|
||||
<< total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
||||
"ReconstructSecretsKernel input num not match.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 output, but got " << outputs.size();
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
||||
"ReconstructSecretsKernel output num not match.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::SendReconstructSecret *reconstruct_secret_req =
|
||||
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
|
||||
|
||||
// get client list from memory server.
|
||||
std::vector<string> client_list;
|
||||
uint64_t update_model_client_num = 0;
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
|
||||
update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "update_model_client_num is zero.";
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
||||
"update_model_client_num is zero.", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return false;
|
||||
}
|
||||
const PBMetadata client_list_pb_out =
|
||||
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||
|
||||
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
|
||||
int client_list_actual_size = client_list_pb.fl_id_size();
|
||||
if (client_list_actual_size < 0) {
|
||||
client_list_actual_size = 0;
|
||||
}
|
||||
if (static_cast<uint64_t>(client_list_actual_size) < update_model_client_num) {
|
||||
MS_LOG(INFO) << "ReconstructSecretsKernel : client list is not ready " << inputs.size();
|
||||
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SucNotReady,
|
||||
"ReconstructSecretsKernel : client list is not ready", iter_num,
|
||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
return response;
|
||||
}
|
||||
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
|
||||
client_list.push_back(client_list_pb.fl_id(i));
|
||||
}
|
||||
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
|
||||
reconstruct_secret_req, fbb, client_list);
|
||||
if (response) {
|
||||
// MS_LOG(INFO) << "start ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
|
||||
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
|
||||
// MS_LOG(INFO) << "end ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
|
||||
}
|
||||
}
|
||||
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "reconstruct_secrets_kernel success time is : " << duration;
|
||||
return response;
|
||||
}
|
||||
|
||||
// Runs when the last count event for this kernel is triggered; message is unused.
// Polls every 5 ms until the executor reports all weight aggregation done, then polls every 5 ms
// until Executor::Unmask() returns true, and finally counts this server's local rank on the
// unmask counter.
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
  MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();

  // todo: only do this when PSContext::instance()->encrypt_type == PWEncrypt; it is
  // currently unconditional.
  const auto poll_interval = std::chrono::milliseconds(5);
  auto &executor = Executor::GetInstance();

  while (!executor.IsAllWeightAggregationDone()) {
    std::this_thread::sleep_for(poll_interval);
  }

  MS_LOG(INFO) << "start unmask";
  while (!executor.Unmask()) {
    std::this_thread::sleep_for(poll_interval);
  }
  MS_LOG(INFO) << "end unmask";

  auto &count_service = DistributedCountService::GetInstance();
  const std::string rank_id = std::to_string(count_service.local_rank());
  count_service.Count(name_unmask_, rank_id);
}
|
||||
|
||||
// Resets per-iteration state: resets both this kernel's counter and the unmask counter, stops
// the timer and clears the reconstruct-secrets data. Always returns true.
bool ReconstructSecretsKernel::Reset() {
  MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
  MS_LOG(INFO) << "reconstruct secrets kernel reset!";

  auto &count_service = DistributedCountService::GetInstance();
  count_service.ResetCounter(name_);
  count_service.ResetCounter(name_unmask_);

  StopTimer();
  cipher_reconstruct_.ClearReconstructSecrets();
  return true;
}
|
||||
|
||||
REG_ROUND_KERNEL(reconstructSecrets, ReconstructSecretsKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
#include "armour/cipher/cipher_reconstruct.h"
|
||||
#include "ps/server/executor.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class ReconstructSecretsKernel : public RoundKernel {
|
||||
public:
|
||||
ReconstructSecretsKernel() = default;
|
||||
~ReconstructSecretsKernel() override = default;
|
||||
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Reset() override;
|
||||
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
|
||||
|
||||
private:
|
||||
std::string name_unmask_;
|
||||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
armour::CipherReconStruct cipher_reconstruct_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/server/kernel/round/share_secrets_kernel.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
void ShareSecretsKernel::InitKernel(size_t) {
|
||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
}
|
||||
|
||||
executor_ = &Executor::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(executor_);
|
||||
if (!executor_->initialized()) {
|
||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
||||
return;
|
||||
}
|
||||
cipher_share_ = &armour::CipherShares::GetInstance();
|
||||
}
|
||||
|
||||
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool response = false;
|
||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is "
|
||||
<< total_duration;
|
||||
clock_t start_time = clock();
|
||||
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 input,but got " << inputs.size();
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError, "ShareSecretsKernel input num not match",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else if (outputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 output,but got " << outputs.size();
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError,
|
||||
"ShareSecretsKernel output num not match",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||
"Current amount for ShareSecretsKernel is enough.",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
void *req_data = inputs[0]->addr;
|
||||
const schema::RequestShareSecrets *share_secrets_req =
|
||||
flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
|
||||
size_t iter_client = (size_t)share_secrets_req->iteration();
|
||||
if (iter_num != iter_client) {
|
||||
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||
<< ". client request iteration is " << iter_client;
|
||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
|
||||
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||
} else {
|
||||
response =
|
||||
cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
|
||||
if (response) {
|
||||
DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||
clock_t end_time = clock();
|
||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
||||
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
|
||||
return response;
|
||||
}
|
||||
|
||||
// Clears this round's state: logs the current iteration, resets the distributed counter
// registered under this kernel's name and stops the round timer. Always returns true.
bool ShareSecretsKernel::Reset() {
  const size_t current_iteration = LocalMetaStore::GetInstance().curr_iter_num();
  MS_LOG(INFO) << "share_secrets_kernel reset! ITERATION NUMBER IS : " << current_iteration;

  auto &count_service = DistributedCountService::GetInstance();
  count_service.ResetCounter(name_);

  StopTimer();
  return true;
}
|
||||
|
||||
// Associates ShareSecretsKernel with the name shareSecrets.
// NOTE(review): presumably this is the same name the server uses when it adds the "shareSecrets" round — confirm.
REG_ROUND_KERNEL(shareSecrets, ShareSecretsKernel)
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
||||
|
||||
#include <vector>
|
||||
#include "ps/server/common.h"
|
||||
#include "ps/server/executor.h"
|
||||
#include "ps/server/kernel/round/round_kernel.h"
|
||||
#include "ps/server/kernel/round/round_kernel_factory.h"
|
||||
#include "armour/cipher/cipher_shares.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace server {
|
||||
namespace kernel {
|
||||
class ShareSecretsKernel : public RoundKernel {
|
||||
public:
|
||||
ShareSecretsKernel() = default;
|
||||
~ShareSecretsKernel() override = default;
|
||||
void InitKernel(size_t required_cnt) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Reset() override;
|
||||
|
||||
private:
|
||||
Executor *executor_;
|
||||
size_t iteration_time_window_;
|
||||
armour::CipherShares *cipher_share_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
|
|
@ -21,6 +21,9 @@
|
|||
#include <vector>
|
||||
#include "ps/server/model_store.h"
|
||||
#include "ps/server/iteration.h"
|
||||
#ifdef ENABLE_ARMOUR
|
||||
#include "armour/cipher/cipher_init.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -192,6 +195,21 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
|||
auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode());
|
||||
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
|
||||
|
||||
#ifdef ENABLE_ARMOUR
|
||||
auto *param = armour::CipherInit::GetInstance().GetPublicParams();
|
||||
auto prime = fbb->CreateVector(param->prime, PRIME_MAX_LEN);
|
||||
auto p = fbb->CreateVector(param->p, SECRET_MAX_LEN);
|
||||
int32_t t = param->t;
|
||||
int32_t g = param->g;
|
||||
float dp_eps = param->dp_eps;
|
||||
float dp_delta = param->dp_delta;
|
||||
float dp_norm_clip = param->dp_norm_clip;
|
||||
auto encrypt_type = fbb->CreateString(param->encrypt_type);
|
||||
|
||||
auto cipher_public_params =
|
||||
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);
|
||||
#endif
|
||||
|
||||
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
|
||||
fl_plan_builder.add_fl_name(fbs_fl_name);
|
||||
fl_plan_builder.add_server_mode(fbs_server_mode);
|
||||
|
@ -199,6 +217,9 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
|||
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
|
||||
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
|
||||
fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate());
|
||||
#ifdef ENABLE_ARMOUR
|
||||
fl_plan_builder.add_cipher(cipher_public_params);
|
||||
#endif
|
||||
auto fbs_fl_plan = fl_plan_builder.Finish();
|
||||
|
||||
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
|
||||
|
|
|
@ -18,6 +18,9 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <csignal>
|
||||
#ifdef ENABLE_ARMOUR
|
||||
#include "armour/secure_protocol/secret_sharing.h"
|
||||
#endif
|
||||
#include "ps/server/round.h"
|
||||
#include "ps/server/model_store.h"
|
||||
#include "ps/server/iteration.h"
|
||||
|
@ -39,7 +42,7 @@ void SignalHandler(int signal) {
|
|||
}
|
||||
|
||||
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
|
||||
const FuncGraphPtr &func_graph, size_t executor_threshold) {
|
||||
const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
func_graph_ = func_graph;
|
||||
|
||||
|
@ -48,6 +51,7 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
|
|||
return;
|
||||
}
|
||||
rounds_config_ = rounds_config;
|
||||
cipher_config_ = cipher_config;
|
||||
|
||||
use_tcp_ = use_tcp;
|
||||
use_http_ = use_http;
|
||||
|
@ -80,6 +84,7 @@ void Server::Run() {
|
|||
RegisterCommCallbacks();
|
||||
StartCommunicator();
|
||||
InitExecutor();
|
||||
InitCipher();
|
||||
RegisterRoundKernel();
|
||||
MS_LOG(INFO) << "Server started successfully.";
|
||||
safemode_ = false;
|
||||
|
@ -177,6 +182,42 @@ void Server::InitIteration() {
|
|||
iteration_->AddRound(round);
|
||||
}
|
||||
|
||||
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
|
||||
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
|
||||
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
|
||||
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
|
||||
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
|
||||
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold;
|
||||
|
||||
MS_LOG(INFO) << "Initializing cipher:";
|
||||
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
|
||||
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
|
||||
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
|
||||
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
|
||||
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
|
||||
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
|
||||
|
||||
#ifdef ENABLE_ARMOUR
|
||||
std::shared_ptr<Round> exchange_keys_round =
|
||||
std::make_shared<Round>("exchangeKeys", false, 3000, true, cipher_exchange_secrets_cnt_);
|
||||
iteration_->AddRound(exchange_keys_round);
|
||||
std::shared_ptr<Round> get_keys_round =
|
||||
std::make_shared<Round>("getKeys", false, 3000, true, cipher_exchange_secrets_cnt_);
|
||||
iteration_->AddRound(get_keys_round);
|
||||
std::shared_ptr<Round> share_secrets_round =
|
||||
std::make_shared<Round>("shareSecrets", false, 3000, true, cipher_share_secrets_cnt_);
|
||||
iteration_->AddRound(share_secrets_round);
|
||||
std::shared_ptr<Round> get_secrets_round =
|
||||
std::make_shared<Round>("getSecrets", false, 3000, true, cipher_share_secrets_cnt_);
|
||||
iteration_->AddRound(get_secrets_round);
|
||||
std::shared_ptr<Round> get_clientlist_round =
|
||||
std::make_shared<Round>("getClientList", false, 3000, true, cipher_get_clientlist_cnt_);
|
||||
iteration_->AddRound(get_clientlist_round);
|
||||
std::shared_ptr<Round> reconstruct_secrets_round =
|
||||
std::make_shared<Round>("reconstructSecrets", false, 3000, true, cipher_reconstruct_secrets_up_cnt_);
|
||||
iteration_->AddRound(reconstruct_secrets_round);
|
||||
#endif
|
||||
|
||||
// 2.Initialize all the rounds.
|
||||
TimeOutCb time_out_cb =
|
||||
std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
|
||||
|
@ -186,6 +227,37 @@ void Server::InitIteration() {
|
|||
return;
|
||||
}
|
||||
|
||||
// Builds the public cipher parameters (threshold t, generator g, secret p and a freshly
// generated random prime) and passes them, together with the per-round client counts, to the
// CipherInit singleton. Compiles to an empty function when ENABLE_ARMOUR is not defined.
void Server::InitCipher() {
#ifdef ENABLE_ARMOUR
  cipher_init_ = &armour::CipherInit::GetInstance();

  int cipher_t = static_cast<int>(cipher_reconstruct_secrets_down_cnt_);
  unsigned char cipher_p[SECRET_MAX_LEN] = {0};
  int cipher_g = 1;
  unsigned char cipher_prime[PRIME_MAX_LEN] = {0};

  mpz_t prim;
  mpz_init(prim);
  mindspore::armour::GetRandomPrime(prim);
  mindspore::armour::PrintBigInteger(prim, 16);

  // mpz_export takes no destination capacity, so compute the number of bytes it will write
  // and refuse to export a prime that does not fit into the fixed-size buffer.
  const size_t prime_bytes = (mpz_sizeinbase(prim, 2) + 7) / 8;
  if (prime_bytes > static_cast<size_t>(PRIME_MAX_LEN)) {
    mpz_clear(prim);
    MS_LOG(EXCEPTION) << "Generated prime needs " << prime_bytes << " bytes, which exceeds PRIME_MAX_LEN.";
    return;
  }

  size_t len_cipher_prime = 0;
  mpz_export(cipher_prime, &len_cipher_prime, sizeof(unsigned char), 1, 0, 0, prim);
  // The big integer is no longer needed once its bytes are exported; release it.
  mpz_clear(prim);

  mindspore::armour::CipherPublicPara param;
  param.g = cipher_g;
  param.t = cipher_t;
  if (memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN) != 0) {
    MS_LOG(EXCEPTION) << "Copying cipher secret p failed.";
    return;
  }
  if (memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN) != 0) {
    MS_LOG(EXCEPTION) << "Copying cipher prime failed.";
    return;
  }
  // NOTE(review): dp_delta, dp_eps and dp_norm_clip are not assigned here although the
  // startFLJob response serializes them — confirm CipherPublicPara gives them defined defaults.
  // param.dp_delta = dp_delta;
  // param.dp_eps = dp_eps;
  // param.dp_norm_clip = dp_norm_clip;
  param.encrypt_type = kNotEncryptType;  // PSContext::instance()->encrypt_type;
  cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_,
                     cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
                     cipher_reconstruct_secrets_up_cnt_);
#endif
}
|
||||
|
||||
void Server::RegisterCommCallbacks() {
|
||||
// The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
|
||||
// rounds' callbacks.
|
||||
|
|
|
@ -26,6 +26,9 @@
|
|||
#include "ps/server/common.h"
|
||||
#include "ps/server/executor.h"
|
||||
#include "ps/server/iteration.h"
|
||||
#ifdef ENABLE_ARMOUR
|
||||
#include "armour/cipher/cipher_init.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
|
@ -39,7 +42,7 @@ class Server {
|
|||
}
|
||||
|
||||
void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
|
||||
const FuncGraphPtr &func_graph, size_t executor_threshold);
|
||||
const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold);
|
||||
|
||||
// According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
|
||||
// blocked until the server is finalized.
|
||||
|
@ -92,6 +95,9 @@ class Server {
|
|||
// Initialize executor according to the server mode.
|
||||
void InitExecutor();
|
||||
|
||||
// Initialize cipher according to the public param.
|
||||
void InitCipher();
|
||||
|
||||
// Create round kernels and bind these kernels with corresponding Round.
|
||||
void RegisterRoundKernel();
|
||||
|
||||
|
@ -120,6 +126,7 @@ class Server {
|
|||
|
||||
// The configure of all rounds.
|
||||
std::vector<RoundConfig> rounds_config_;
|
||||
CipherConfig cipher_config_;
|
||||
|
||||
// The graph passed by the frontend without backend optimizing.
|
||||
FuncGraphPtr func_graph_;
|
||||
|
@ -147,10 +154,25 @@ class Server {
|
|||
std::atomic_bool safemode_;
|
||||
|
||||
// Variables set by ps context.
|
||||
#ifdef ENABLE_ARMOUR
|
||||
armour::CipherInit *cipher_init_;
|
||||
#endif
|
||||
std::string scheduler_ip_;
|
||||
uint16_t scheduler_port_;
|
||||
uint32_t server_num_;
|
||||
uint32_t worker_num_;
|
||||
uint16_t fl_server_port_;
|
||||
size_t start_fl_job_cnt_;
|
||||
size_t update_model_cnt_;
|
||||
size_t cipher_initial_client_cnt_;
|
||||
size_t cipher_exchange_secrets_cnt_;
|
||||
size_t cipher_share_secrets_cnt_;
|
||||
size_t cipher_get_clientlist_cnt_;
|
||||
size_t cipher_reconstruct_secrets_up_cnt_;
|
||||
size_t cipher_reconstruct_secrets_down_cnt_;
|
||||
|
||||
float percent_for_update_model_;
|
||||
float percent_for_get_model_;
|
||||
};
|
||||
} // namespace server
|
||||
} // namespace ps
|
||||
|
|
|
@ -836,14 +836,15 @@ def set_fl_context(**kwargs):
|
|||
fl_server_port (int): The http port of the federated learning server.
|
||||
Normally for each server this should be set to the same value. Default: 6668.
|
||||
enable_fl_client (bool): Whether this process is federated learning client. Default: False.
|
||||
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0.
|
||||
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 1.
|
||||
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
|
||||
update_model_ratio (float): The ratio for computing the threshold count of updateModel
|
||||
which will be multiplied by start_fl_job_threshold.
|
||||
Must be between 0 and 1.0.Default: 1.0.
|
||||
share_secrets_ratio (float): The ratio for computing the threshold count of share secrets. Default: 1.0.
|
||||
update_model_ratio (float): The ratio for computing the threshold count of updateModel. Default: 1.0.
|
||||
get_model_ratio (float): The ratio for computing the threshold count of get model. Default: 1.0.
|
||||
reconstruct_secrets_threshhold (int): The threshold count of clients for reconstructing secrets. Default: 0.
|
||||
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
|
||||
fl_name (str): The federated learning job name. Default: ''.
|
||||
fl_iteration_num (int): Iteration number of federeated learning,
|
||||
fl_name (string): The federated learning job name. Default: ''.
|
||||
fl_iteration_num (int): Iteration number of federated learning,
|
||||
which is the number of interactions between client and server. Default: 20.
|
||||
client_epoch_num (int): Client training epoch number. Default: 25.
|
||||
client_batch_size (int): Client training data batch size. Default: 32.
|
||||
|
|
|
@ -45,6 +45,7 @@ static const std::vector<std::string> sub_module_names = {
|
|||
"PROFILER", // SM_PROFILER
|
||||
"PS", // SM_PS
|
||||
"LITE", // SM_LITE
|
||||
"ARMOUR", // SM_ARMOUR
|
||||
"HCCL_ADPT", // SM_HCCL_ADPT
|
||||
"MINDQUANTUM", // SM_MINDQUANTUM
|
||||
"RUNTIME_FRAMEWORK", // SM_RUNTIME_FRAMEWORK
|
||||
|
|
|
@ -132,6 +132,7 @@ enum SubModuleId : int {
|
|||
SM_PROFILER, // profiler
|
||||
SM_PS, // Parameter Server
|
||||
SM_LITE, // LITE
|
||||
SM_ARMOUR, // ARMOUR
|
||||
SM_HCCL_ADPT, // Hccl Adapter
|
||||
SM_MINDQUANTUM, // MindQuantum
|
||||
SM_RUNTIME_FRAMEWORK, // Runtime framework
|
||||
|
|
|
@ -57,6 +57,9 @@ _set_ps_context_func_map = {
|
|||
"start_fl_job_time_window": ps_context().set_start_fl_job_time_window,
|
||||
"update_model_ratio": ps_context().set_update_model_ratio,
|
||||
"update_model_time_window": ps_context().set_update_model_time_window,
|
||||
"share_secrets_ratio": ps_context().set_share_secrets_ratio,
|
||||
"get_model_ratio": ps_context().set_get_model_ratio,
|
||||
"reconstruct_secrets_threshhold": ps_context().set_reconstruct_secrets_threshhold,
|
||||
"fl_name": ps_context().set_fl_name,
|
||||
"fl_iteration_num": ps_context().set_fl_iteration_num,
|
||||
"client_epoch_num": ps_context().set_client_epoch_num,
|
||||
|
|
|
@ -24,7 +24,7 @@ table CipherPublicParams {
|
|||
dp_eps:float;
|
||||
dp_delta:float;
|
||||
dp_norm_clip:float;
|
||||
encrypt_type:int;
|
||||
encrypt_type:string;
|
||||
}
|
||||
|
||||
table ClientPublicKeys {
|
||||
|
|
|
@ -28,6 +28,9 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
|||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--update_model_time_window", type=int, default=3000)
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--get_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0)
|
||||
parser.add_argument("--fl_name", type=str, default="Lenet")
|
||||
parser.add_argument("--fl_iteration_num", type=int, default=25)
|
||||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
|
@ -48,6 +51,9 @@ if __name__ == "__main__":
|
|||
start_fl_job_time_window = args.start_fl_job_time_window
|
||||
update_model_ratio = args.update_model_ratio
|
||||
update_model_time_window = args.update_model_time_window
|
||||
share_secrets_ratio = args.share_secrets_ratio
|
||||
get_model_ratio = args.get_model_ratio
|
||||
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
|
||||
fl_name = args.fl_name
|
||||
fl_iteration_num = args.fl_iteration_num
|
||||
client_epoch_num = args.client_epoch_num
|
||||
|
@ -78,6 +84,9 @@ if __name__ == "__main__":
|
|||
cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
|
||||
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
|
||||
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
|
||||
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
|
||||
cmd_server += " --get_model_ratio=" + str(get_model_ratio)
|
||||
cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold)
|
||||
cmd_server += " --fl_name=" + fl_name
|
||||
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
|
||||
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
|
||||
|
|
|
@ -36,6 +36,9 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1)
|
|||
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
|
||||
parser.add_argument("--update_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--update_model_time_window", type=int, default=3000)
|
||||
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--get_model_ratio", type=float, default=1.0)
|
||||
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0)
|
||||
parser.add_argument("--fl_name", type=str, default="Lenet")
|
||||
parser.add_argument("--fl_iteration_num", type=int, default=25)
|
||||
parser.add_argument("--client_epoch_num", type=int, default=20)
|
||||
|
@ -56,6 +59,9 @@ start_fl_job_threshold = args.start_fl_job_threshold
|
|||
start_fl_job_time_window = args.start_fl_job_time_window
|
||||
update_model_ratio = args.update_model_ratio
|
||||
update_model_time_window = args.update_model_time_window
|
||||
share_secrets_ratio = args.share_secrets_ratio
|
||||
get_model_ratio = args.get_model_ratio
|
||||
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
|
||||
fl_name = args.fl_name
|
||||
fl_iteration_num = args.fl_iteration_num
|
||||
client_epoch_num = args.client_epoch_num
|
||||
|
@ -76,6 +82,9 @@ ctx = {
|
|||
"start_fl_job_time_window": start_fl_job_time_window,
|
||||
"update_model_ratio": update_model_ratio,
|
||||
"update_model_time_window": update_model_time_window,
|
||||
"share_secrets_ratio": share_secrets_ratio,
|
||||
"get_model_ratio": get_model_ratio,
|
||||
"reconstruct_secrets_threshhold": reconstruct_secrets_threshhold,
|
||||
"fl_name": fl_name,
|
||||
"fl_iteration_num": fl_iteration_num,
|
||||
"client_epoch_num": client_epoch_num,
|
||||
|
|
|
@ -159,6 +159,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/ps/*.cc"
|
||||
"../../../mindspore/ccsrc/profiler/device/common/*.cc"
|
||||
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c"
|
||||
"../../../mindspore/ccsrc/armour/*.cc"
|
||||
)
|
||||
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST
|
||||
|
|
Loading…
Reference in New Issue