cipher calling code

This commit is contained in:
ql_12345 2021-06-17 09:49:15 +08:00 committed by yangyuan
parent 3bf35c3ae7
commit 5af0c890a1
41 changed files with 1410 additions and 47 deletions

View File

@ -227,6 +227,11 @@ set(SUB_COMP
common debug pybind_api utils vm profiler ps
)
if(ENABLE_CPU AND NOT WIN32)
add_compile_definitions(ENABLE_ARMOUR)
list(APPEND SUB_COMP "armour")
endif()
foreach(_comp ${SUB_COMP})
add_subdirectory(${_comp})
string(REPLACE "/" "_" sub ${_comp})

View File

@ -1,14 +1,5 @@
file(GLOB_RECURSE ARMOUR_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_init.cc")
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_keys.cc")
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_meta_storage.cc")
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_reconstruct.cc")
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_shares.cc")
list(REMOVE_ITEM ARMOUR_FILES "cipher/cipher_unmask.cc")
endif()
set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema")
set(FBS_FILES
${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CIPHER_INIT_H
#define MINDSPORE_CIPHER_INIT_H
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_INIT_H
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_INIT_H
#include <vector>
#include <string>
@ -74,4 +74,4 @@ class CipherInit {
} // namespace armour
} // namespace mindspore
#endif // MINDSPORE_CIPHER_COMMON_H
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_COMMON_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CIPHER_KEYS_H
#define MINDSPORE_CIPHER_KEYS_H
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H
#include <vector>
#include <string>
@ -68,4 +68,4 @@ class CipherKeys {
} // namespace armour
} // namespace mindspore
#endif // MINDSPORE_CIPHER_KEYS_H
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H

View File

@ -14,16 +14,18 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CIPHER_META_STORAGE_H
#define MINDSPORE_CIPHER_META_STORAGE_H
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_META_STORAGE_H
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_META_STORAGE_H
#include <gmp.h>
#include <utility>
#include <algorithm>
#include <map>
#include <vector>
#include <string>
#include <memory>
#ifndef _WIN32
#include <gmp.h>
#endif
#include "proto/ps.pb.h"
#include "utils/log_adapter.h"
#include "armour/secure_protocol/secret_sharing.h"
@ -52,7 +54,7 @@ struct CipherPublicPara {
float dp_eps;
float dp_delta;
float dp_norm_clip;
int encrypt_type;
string encrypt_type;
};
class CipherMetaStorage {
@ -89,4 +91,4 @@ class CipherMetaStorage {
} // namespace armour
} // namespace mindspore
#endif // MINDSPORE_CIPHER_META_STORAGE_H
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_META_STORAGE_H

View File

@ -29,6 +29,10 @@ bool CipherReconStruct::CombineMask(
const std::map<std::string, std::vector<clientshare_str>> &reconstruct_secret_list,
const std::vector<string> &client_list) {
bool retcode = true;
#ifdef _WIN32
MS_LOG(ERROR) << "Unsupported feature in Windows platform.";
retcode = false;
#else
for (auto iter = reconstruct_secret_list.begin(); iter != reconstruct_secret_list.end(); ++iter) {
// define flag_share: judge we need b or s
bool flag_share = true;
@ -45,7 +49,6 @@ bool CipherReconStruct::CombineMask(
mpz_init(prime);
auto publicparam_ = CipherInit::GetInstance().GetPublicParams();
mpz_import(prime, PRIME_MAX_LEN, 1, 1, 0, 0, publicparam_->prime);
if (iter->second.size() >= cipher_init_->secrets_minnums_) { // combine private key seed.
MS_LOG(INFO) << "start assign secrets shares to public shares ";
for (int i = 0; i < static_cast<int>(cipher_init_->secrets_minnums_); ++i) {
@ -89,7 +92,7 @@ bool CipherReconStruct::CombineMask(
}
}
}
#endif
return retcode;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CIPHER_RECONSTRUCT_H
#define MINDSPORE_CIPHER_RECONSTRUCT_H
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_RECONSTRUCT_H
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_RECONSTRUCT_H
#include <vector>
#include <string>
@ -84,4 +84,4 @@ class CipherReconStruct {
} // namespace armour
} // namespace mindspore
#endif // MINDSPORE_CIPHER_KEYS_H
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_KEYS_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CIPHER_SHARES_H
#define MINDSPORE_CIPHER_SHARES_H
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_SHARES_H
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_SHARES_H
#include <vector>
#include <string>
@ -65,4 +65,4 @@ class CipherShares {
} // namespace armour
} // namespace mindspore
#endif // MINDSPORE_CIPHER_KEYS_H
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_SHARES_H

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CIPHER_UNMASK_H
#define MINDSPORE_CIPHER_UNMASK_H
#ifndef MINDSPORE_CCSRC_ARMOUR_CIPHER_UNMASK_H
#define MINDSPORE_CCSRC_ARMOUR_CIPHER_UNMASK_H
#include <vector>
#include <string>
@ -42,4 +42,4 @@ class CipherUnmask {
} // namespace armour
} // namespace mindspore
#endif // MINDSPORE_CIPHER_KEYS_H
#endif // MINDSPORE_CCSRC_ARMOUR_CIPHER_UNMASK_H

View File

@ -672,14 +672,21 @@ bool StartServerAction(const ResourcePtr &res) {
{"pullWeight"},
{"pushWeight", false, 3000, true, server_num, true}};
float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
float get_model_ratio = ps::PSContext::instance()->get_model_ratio();
size_t reconstruct_secrets_threshhold = ps::PSContext::instance()->reconstruct_secrets_threshhold();
ps::server::CipherConfig cipher_config = {share_secrets_ratio, get_model_ratio, reconstruct_secrets_threshhold};
size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
executor_threshold = update_model_threshold;
ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, func_graph,
ps::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
executor_threshold);
} else if (server_mode_ == ps::kServerModePS) {
executor_threshold = worker_num;
ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, func_graph, executor_threshold);
ps::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
executor_threshold);
} else {
MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
return false;

View File

@ -354,6 +354,11 @@ PYBIND11_MODULE(_c_expression, m) {
"Set threshold count ratio for updateModel round.")
.def("set_update_model_time_window", &PSContext::set_update_model_time_window,
"Set time window for updateModel round.")
.def("set_share_secrets_ratio", &PSContext::set_share_secrets_ratio,
"Set threshold count ratio for share secrets round.")
.def("set_get_model_ratio", &PSContext::set_get_model_ratio, "Set threshold count ratio for get model round.")
.def("set_reconstruct_secrets_threshhold", &PSContext::set_reconstruct_secrets_threshhold,
"Set threshold count for reconstruct secrets round.")
.def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
.def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
.def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")

View File

@ -105,6 +105,10 @@ message ClientNoises {
OneClientNoises one_client_noises = 1;
}
message PrimeList {
repeated bytes prime = 1;
}
message PairClientKeys {
string fl_id = 1;
KeysPb client_keys = 2;
@ -128,6 +132,10 @@ message KeysPb {
repeated bytes key = 1;
}
message Prime {
bytes prime = 1;
}
message PBMetadata {
oneof value {
DeviceMeta device_meta = 1;
@ -146,6 +154,9 @@ message PBMetadata {
OneClientNoises one_client_noises = 10;
ClientNoises client_noises = 11;
Prime prime = 12;
PrimeList prime_list = 13;
}
}

View File

@ -294,6 +294,20 @@ void PSContext::set_update_model_time_window(uint64_t update_model_time_window)
uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
void PSContext::set_share_secrets_ratio(float share_secrets_ratio) { share_secrets_ratio_ = share_secrets_ratio; }
float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
void PSContext::set_get_model_ratio(float get_model_ratio) { get_model_ratio_ = get_model_ratio; }
float PSContext::get_model_ratio() const { return get_model_ratio_; }
void PSContext::set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold) {
reconstruct_secrets_threshhold_ = reconstruct_secrets_threshhold;
}
uint64_t PSContext::reconstruct_secrets_threshhold() const { return reconstruct_secrets_threshhold_; }
void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
const std::string &PSContext::fl_name() const { return fl_name_; }

View File

@ -35,6 +35,9 @@ constexpr char kEnvRoleOfServer[] = "MS_SERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
constexpr char kDPEncryptType[] = "DPEncrypt";
constexpr char kPWEncryptType[] = "PWEncrypt";
constexpr char kNotEncryptType[] = "NotEncrypt";
// Use binary data to represent federated learning server's context so that we can judge which round resets the
// iteration. From right to left, each bit stands for:
@ -125,6 +128,15 @@ class PSContext {
void set_update_model_time_window(uint64_t update_model_time_window);
uint64_t update_model_time_window() const;
void set_share_secrets_ratio(float share_secrets_ratio);
float share_secrets_ratio() const;
void set_get_model_ratio(float get_model_ratio);
float get_model_ratio() const;
void set_reconstruct_secrets_threshhold(uint64_t reconstruct_secrets_threshhold);
uint64_t reconstruct_secrets_threshhold() const;
void set_fl_name(const std::string &fl_name);
const std::string &fl_name() const;
@ -173,6 +185,9 @@ class PSContext {
start_fl_job_time_window_(3000),
update_model_ratio_(1.0),
update_model_time_window_(3000),
share_secrets_ratio_(1.0),
get_model_ratio_(1.0),
reconstruct_secrets_threshhold_(2000),
fl_iteration_num_(20),
client_epoch_num_(25),
client_batch_size_(32),
@ -222,6 +237,15 @@ class PSContext {
// The time window of updateModel round in millisecond.
uint64_t update_model_time_window_;
// Share model threshold is a certain ratio of share secrets threshold which is set as share_secrets_ratio_.
float share_secrets_ratio_;
// Get model threshold is a certain ratio of get model threshold which is set as get_model_ratio_.
float get_model_ratio_;
// The threshold count of reconstruct secrets round. Used in federated learning for now.
uint64_t reconstruct_secrets_threshhold_;
// Iteration number of federeated learning, which is the number of interactions between client and server.
uint64_t fl_iteration_num_;

View File

@ -59,6 +59,12 @@ struct RoundConfig {
bool server_num_as_threshold = false;
};
struct CipherConfig {
float share_secrets_ratio = 1.0;
float get_model_ratio = 1.0;
size_t reconstruct_secrets_threshhold = 0;
};
using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::CPUKernel;
@ -175,9 +181,17 @@ constexpr auto kCtxDeviceMetas = "device_metas";
constexpr auto kCtxTotalTimeoutDuration = "total_timeout_duration";
constexpr auto kCtxIterationNextRequestTimestamp = "iteration_next_request_timestamp";
constexpr auto kCtxUpdateModelClientList = "update_model_client_list";
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
constexpr auto kCtxUpdateModelThld = "update_model_threshold";
constexpr auto kCtxUpdateModelClientNum = "update_model_client_num";
constexpr auto kCtxClientsKeys = "clients_keys";
constexpr auto kCtxClientNoises = "clients_noises";
constexpr auto kCtxClientsEncryptedShares = "clients_encrypted_shares";
constexpr auto kCtxClientsReconstructShares = "clients_restruct_shares";
constexpr auto kCtxShareSecretsClientList = "share_secrets_client_list";
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
constexpr auto kCtxCipherPrimer = "cipher_primer";
// This macro the current timestamp in milliseconds.
#define CURRENT_TIME_MILLI \

View File

@ -238,6 +238,63 @@ bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const P
} else if (meta.has_update_model_threshold()) {
auto update_model_threshold = metadata_[name].mutable_update_model_threshold();
*update_model_threshold = meta.update_model_threshold();
} else if (meta.has_prime()) {
auto prime_list = metadata_[name].mutable_prime_list();
auto &prime_id = meta.prime().prime();
if (prime_list->prime_size() == 0) {
prime_list->add_prime(prime_id);
}
} else if (meta.has_pair_client_keys()) {
auto &client_keys_map = *metadata_[name].mutable_client_keys()->mutable_client_keys();
auto &fl_id = meta.pair_client_keys().fl_id();
auto &client_keys = meta.pair_client_keys().client_keys();
// Check whether the new item already exists.
bool add_flag = true;
for (auto iter = client_keys_map.begin(); iter != client_keys_map.end(); iter++) {
if (fl_id == iter->first) {
add_flag = false;
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
break;
}
}
if (add_flag) {
client_keys_map[fl_id] = client_keys;
} else {
return false;
}
} else if (meta.has_pair_client_shares()) {
auto &client_shares_map = *metadata_[name].mutable_client_shares()->mutable_client_secret_shares();
auto &fl_id = meta.pair_client_shares().fl_id();
auto &client_shares = meta.pair_client_shares().client_shares();
// google::protobuf::Map< std::string, mindspore::ps::core::SharesPb >::const_iterator iter;
// Check whether the new item already exists.
bool add_flag = true;
for (auto iter = client_shares_map.begin(); iter != client_shares_map.end(); iter++) {
if (fl_id == iter->first) {
add_flag = false;
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
break;
}
}
if (add_flag) {
client_shares_map[fl_id] = client_shares;
} else {
return false;
}
} else if (meta.has_one_client_noises()) {
auto &client_noises = *metadata_[name].mutable_client_noises();
if (client_noises.has_one_client_noises()) {
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value already exists.";
client_noises.Clear();
}
client_noises.mutable_one_client_noises()->MergeFrom(meta.one_client_noises());
} else {
MS_LOG(ERROR) << "Leader server updating value for " << name
<< " failed: The Protobuffer of this value is not defined.";
return false;
}
return true;
}

View File

@ -268,10 +268,14 @@ std::map<std::string, AddressPtr> Executor::GetModel() {
return model;
}
// bool Executor::Unmask() {
// auto model = GetModel();
// return mindarmour::CipherMgr::GetInstance().UnMask(model);
// }
bool Executor::Unmask() {
#ifdef ENABLE_ARMOUR
auto model = GetModel();
return cipher_unmask_.UnMask(model);
#else
return false;
#endif
}
const std::vector<std::string> &Executor::param_names() const { return param_names_; }

View File

@ -26,6 +26,9 @@
#include <condition_variable>
#include "ps/server/common.h"
#include "ps/server/parameter_aggregator.h"
#ifdef ENABLE_ARMOUR
#include "armour/cipher/cipher_unmask.h"
#endif
namespace mindspore {
namespace ps {
@ -88,6 +91,7 @@ class Executor {
bool initialized() const;
const std::vector<std::string> &param_names() const;
bool Unmask();
private:
Executor() {}
@ -117,6 +121,9 @@ class Executor {
// Because ParameterAggregator is not threadsafe, we have to create mutex for each ParameterAggregator so we can
// acquire lock before calling its method.
std::map<std::string, std::mutex> parameter_mutex_;
#ifdef ENABLE_ARMOUR
armour::CipherUnmask cipher_unmask_;
#endif
};
} // namespace server
} // namespace ps

View File

@ -0,0 +1,199 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/client_list_kernel.h"
#include <utility>
#include <string>
#include <vector>
#include <memory>
#include "schema/cipher_generated.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void ClientListKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_init_ = &armour::CipherInit::GetInstance();
}
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
std::shared_ptr<server::FBBuilder> fbb) {
bool response = false;
std::vector<string> client_list;
std::string fl_id = get_clients_req->fl_id()->str();
int32_t iter_client = (size_t)get_clients_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "ClientListKernel iteration invalid. servertime is " << iter_num;
MS_LOG(ERROR) << "ClientListKernel iteration invalid. clienttime is " << iter_client;
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
uint64_t update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
PBMetadata client_list_pb_out = DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
}
if (find(client_list.begin(), client_list.end(), fl_id) != client_list.end()) { // client in client_list.
if (static_cast<uint64_t>(client_list_pb.fl_id_size()) >= update_model_client_num) {
MS_LOG(INFO) << "send clients_list succeed!";
MS_LOG(INFO) << "UpdateModel client list: ";
for (size_t i = 0; i < client_list.size(); ++i) {
MS_LOG(INFO) << " fl_id : " << client_list[i];
}
MS_LOG(INFO) << "update_model_client_num: " << update_model_client_num;
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
response = true;
} else {
MS_LOG(INFO) << "The server is not ready. update_model_client_need_num: " << update_model_client_num;
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
/*for (size_t i = 0; i < std::min(client_list.size(), size_t(2)); ++i) {
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
}
for (size_t i = client_list.size() - size_t(1); i > std::max(client_list.size() - size_t(2), size_t(0));
--i) {
MS_LOG(INFO) << " client_list fl_id : " << client_list[i];
}*/
int count_tmp = 0;
for (size_t i = 0; i < cipher_init_->get_model_num_need_; ++i) {
size_t j = 0;
for (; j < client_list.size(); ++j) {
if (("f" + std::to_string(i)) == client_list[j]) break;
}
if (j >= client_list.size()) {
count_tmp++;
MS_LOG(INFO) << " no client_list fl_id : " << i;
if (count_tmp > 3) break;
}
}
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
}
}
if (response) {
DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str());
}
} else {
MS_LOG(ERROR) << "update_model_client_num is zero.";
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_num is zero.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
}
}
return response;
}
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
clock_t start_time = clock();
std::vector<string> client_list;
if (inputs.size() != 1) {
MS_LOG(ERROR) << "ClientListKernel needs 1 input,but got " << inputs.size();
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel input num not match", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ClientListKernel needs 1 output,but got " << outputs.size();
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "ClientListKernel output num not match", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "ClientListKernel num is enough", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
void *req_data = inputs[0]->addr;
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
if (get_clients_req == nullptr || fbb == nullptr) {
MS_LOG(ERROR) << "GetClientList is nullptr or ClientListRsp builder is nullptr.";
BuildClientListRsp(fbb, schema::ResponseCode_RequestError,
"GetClientList is nullptr or ClientListRsp builder is nullptr.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
response = DealClient(iter_num, get_clients_req, fbb);
}
}
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
return response;
} // namespace ps
bool ClientListKernel::Reset() {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Get Client list kernel reset!";
DistributedCountService::GetInstance().ResetCounter(name_);
StopTimer();
return true;
}
void ClientListKernel::BuildClientListRsp(std::shared_ptr<server::FBBuilder> client_list_resp_builder,
const schema::ResponseCode retcode, const string &reason,
std::vector<std::string> clients, const string &next_req_time,
const int iteration) {
auto rsp_reason = client_list_resp_builder->CreateString(reason);
auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time);
if (clients.size() > 0) {
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
for (auto client : clients) {
auto client_fb = client_list_resp_builder->CreateString(client);
clients_vector.push_back(client_fb);
}
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_clients(clients_fb);
rsp_builder.add_iteration(iteration);
rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys);
} else {
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
rsp_builder.add_retcode(retcode);
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_iteration(iteration);
rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys);
}
return;
}
REG_ROUND_KERNEL(getClientList, ClientListKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H
#include <string>
#include <vector>
#include <memory>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "armour/cipher/cipher_init.h"
#include "ps/server/executor.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class ClientListKernel : public RoundKernel {
public:
ClientListKernel() = default;
~ClientListKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
void BuildClientListRsp(std::shared_ptr<server::FBBuilder> client_list_resp_builder,
const schema::ResponseCode retcode, const string &reason, std::vector<std::string> clients,
const string &next_req_time, const int iteration);
private:
armour::CipherInit *cipher_init_;
bool DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
std::shared_ptr<server::FBBuilder> fbb);
Executor *executor_;
size_t iteration_time_window_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_CLIENT_LIST_KERNEL_H

View File

@ -0,0 +1,101 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/exchange_keys_kernel.h"
#include <vector>
#include <utility>
#include <memory>
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void ExchangeKeysKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_key_ = &armour::CipherKeys::GetInstance();
}
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 input,but got " << inputs.size();
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel input num not match",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ExchangeKeysKernel needs 1 output,but got " << outputs.size();
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_SystemError, "ExchangeKeysKernel output num not match",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ExchangeKeysKernel is enough.";
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ExchangeKeysKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
void *req_data = inputs[0]->addr;
const schema::RequestExchangeKeys *exchange_keys_req =
flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
int32_t iter_client = (size_t)exchange_keys_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "ExchangeKeysKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
response =
cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
if (response) {
DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str());
}
}
}
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
return response;
}
bool ExchangeKeysKernel::Reset() {
MS_LOG(INFO) << "exchange keys kernel reset, ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
DistributedCountService::GetInstance().ResetCounter(name_);
StopTimer();
return true;
}
REG_ROUND_KERNEL(exchangeKeys, ExchangeKeysKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H
#include <vector>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "ps/server/executor.h"
#include "armour/cipher/cipher_keys.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class ExchangeKeysKernel : public RoundKernel {
public:
ExchangeKeysKernel() = default;
~ExchangeKeysKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
private:
Executor *executor_;
size_t iteration_time_window_;
armour::CipherKeys *cipher_key_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_EXCHANGE_KEYS_KERNEL_H

View File

@ -0,0 +1,100 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/get_keys_kernel.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void GetKeysKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_key_ = &armour::CipherKeys::GetInstance();
}
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetKeysKernel allowed Duration Is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "GetKeysKernel needs 1 input,but got " << inputs.size();
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "GetKeysKernel needs 1 output,but got " << outputs.size();
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_SystemError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
} else {
void *req_data = inputs[0]->addr;
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
int32_t iter_client = (size_t)get_exchange_keys_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_key_->BuildGetKeys(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
} else {
response =
cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
if (response) {
DistributedCountService::GetInstance().Count(name_, get_exchange_keys_req->fl_id()->str());
}
}
}
}
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
return response;
}
bool GetKeysKernel::Reset() {
MS_LOG(INFO) << "get keys kernel reset! ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
cipher_key_->ClearKeys();
DistributedCountService::GetInstance().ResetCounter(name_);
StopTimer();
return true;
}
REG_ROUND_KERNEL(getKeys, GetKeysKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H
#include <vector>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "ps/server/executor.h"
#include "armour/cipher/cipher_keys.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class GetKeysKernel : public RoundKernel {
public:
GetKeysKernel() = default;
~GetKeysKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
private:
Executor *executor_;
size_t iteration_time_window_;
armour::CipherKeys *cipher_key_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_KEYS_KERNEL_H

View File

@ -0,0 +1,103 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/get_secrets_kernel.h"
#include <vector>
#include <memory>
#include <string>
#include "armour/cipher/cipher_shares.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void GetSecretsKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_share_ = &armour::CipherShares::GetInstance();
}
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
bool response = false;
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "GetSecretsKernel needs 1 input,but got " << inputs.size();
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "GetSecretsKernel needs 1 output,but got " << outputs.size();
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_SystemError, iter_num, next_timestamp, 0);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
} else {
void *req_data = inputs[0]->addr;
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
int32_t iter_client = (size_t)get_secrets_req->iteration();
if (iter_num != (size_t)iter_client) {
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, 0);
} else {
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
if (response) {
DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str());
}
}
}
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
return response;
}
bool GetSecretsKernel::Reset() {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "GetSecretsKernel reset!";
cipher_share_->ClearShareSecrets();
DistributedCountService::GetInstance().ResetCounter(name_);
StopTimer();
return true;
}
REG_ROUND_KERNEL(getSecrets, GetSecretsKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H
#include <vector>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "armour/cipher/cipher_shares.h"
#include "ps/server/executor.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class GetSecretsKernel : public RoundKernel {
public:
GetSecretsKernel() = default;
~GetSecretsKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
private:
Executor *executor_;
size_t iteration_time_window_;
armour::CipherShares *cipher_share_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_GET_SECRETS_KERNEL_H

View File

@ -0,0 +1,165 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/reconstruct_secrets_kernel.h"
#include <string>
#include <vector>
#include <memory>
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void ReconstructSecretsKernel::InitKernel(size_t required_cnt) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
auto last_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) {
MS_LOG(INFO) << "start FinishIteration";
FinishIteration();
MS_LOG(INFO) << "end FinishIteration";
return;
};
auto first_cnt_handler = [&](std::shared_ptr<core::MessageHandler>) { return; };
name_unmask_ = "UnMaskKernel";
MS_LOG(INFO) << "ReconstructSecretsKernel Init, ITERATION NUMBER IS : "
<< LocalMetaStore::GetInstance().curr_iter_num();
DistributedCountService::GetInstance().RegisterCounter(name_unmask_, PSContext::instance()->initial_server_num(),
{first_cnt_handler, last_cnt_handler});
}
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
// MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ReconstructSecretsKernel total duration is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
"ReconstructSecretsKernel input num not match.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 output, but got " << outputs.size();
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
"ReconstructSecretsKernel output num not match.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ReconstructSecretsKernel is enough.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
void *req_data = inputs[0]->addr;
const schema::SendReconstructSecret *reconstruct_secret_req =
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
// get client list from memory server.
std::vector<string> client_list;
uint64_t update_model_client_num = 0;
if (LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
update_model_client_num = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
} else {
MS_LOG(ERROR) << "update_model_client_num is zero.";
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SystemError,
"update_model_client_num is zero.", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return false;
}
const PBMetadata client_list_pb_out =
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &client_list_pb = client_list_pb_out.client_list();
int client_list_actual_size = client_list_pb.fl_id_size();
if (client_list_actual_size < 0) {
client_list_actual_size = 0;
}
if (static_cast<uint64_t>(client_list_actual_size) < update_model_client_num) {
MS_LOG(INFO) << "ReconstructSecretsKernel : client list is not ready " << inputs.size();
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_SucNotReady,
"ReconstructSecretsKernel : client list is not ready", iter_num,
std::to_string(CURRENT_TIME_MILLI.count()));
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return response;
}
for (int i = 0; i < client_list_pb.fl_id_size(); ++i) {
client_list.push_back(client_list_pb.fl_id(i));
}
response = cipher_reconstruct_.ReconstructSecrets(iter_num, std::to_string(CURRENT_TIME_MILLI.count()),
reconstruct_secret_req, fbb, client_list);
if (response) {
// MS_LOG(INFO) << "start ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
DistributedCountService::GetInstance().Count(name_, reconstruct_secret_req->fl_id()->str());
// MS_LOG(INFO) << "end ReconstructSecretsKernel Success. fl_id : " << reconstruct_secret_req->fl_id()->str();
}
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "reconstruct_secrets_kernel success time is : " << duration;
return response;
}
void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
if (true) { // todo: PSContext::instance()->encrypt_type == PWEncrypt {
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
MS_LOG(INFO) << "start unmask";
while (!Executor::GetInstance().Unmask()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}
MS_LOG(INFO) << "end unmask";
std::string worker_id = std::to_string(DistributedCountService::GetInstance().local_rank());
DistributedCountService::GetInstance().Count(name_unmask_, worker_id);
}
}
bool ReconstructSecretsKernel::Reset() {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "reconstruct secrets kernel reset!";
DistributedCountService::GetInstance().ResetCounter(name_);
DistributedCountService::GetInstance().ResetCounter(name_unmask_);
StopTimer();
cipher_reconstruct_.ClearReconstructSecrets();
return true;
}
REG_ROUND_KERNEL(reconstructSecrets, ReconstructSecretsKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include "ps/server/common.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "armour/cipher/cipher_reconstruct.h"
#include "ps/server/executor.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class ReconstructSecretsKernel : public RoundKernel {
public:
ReconstructSecretsKernel() = default;
~ReconstructSecretsKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) override;
private:
std::string name_unmask_;
Executor *executor_;
size_t iteration_time_window_;
armour::CipherReconStruct cipher_reconstruct_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_RECONSTRUCT_SECRETS_KERNEL_H_

View File

@ -0,0 +1,102 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/server/kernel/round/share_secrets_kernel.h"
#include <vector>
#include <memory>
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
void ShareSecretsKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_share_ = &armour::CipherShares::GetInstance();
}
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
bool response = false;
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1) {
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 input,but got " << inputs.size();
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError, "ShareSecretsKernel input num not match",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else if (outputs.size() != 1) {
MS_LOG(ERROR) << "ShareSecretsKernel needs 1 output,but got " << outputs.size();
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_SystemError,
"ShareSecretsKernel output num not match",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for ShareSecretsKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
void *req_data = inputs[0]->addr;
const schema::RequestShareSecrets *share_secrets_req =
flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
size_t iter_client = (size_t)share_secrets_req->iteration();
if (iter_num != iter_client) {
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, "ShareSecretsKernel iteration invalid",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
response =
cipher_share_->ShareSecrets(iter_num, share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count()));
if (response) {
DistributedCountService::GetInstance().Count(name_, share_secrets_req->fl_id()->str());
}
}
}
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
return response;
}
bool ShareSecretsKernel::Reset() {
MS_LOG(INFO) << "share_secrets_kernel reset! ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
DistributedCountService::GetInstance().ResetCounter(name_);
StopTimer();
return true;
}
REG_ROUND_KERNEL(shareSecrets, ShareSecretsKernel)
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#define MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H
#include <vector>
#include "ps/server/common.h"
#include "ps/server/executor.h"
#include "ps/server/kernel/round/round_kernel.h"
#include "ps/server/kernel/round/round_kernel_factory.h"
#include "armour/cipher/cipher_shares.h"
namespace mindspore {
namespace ps {
namespace server {
namespace kernel {
class ShareSecretsKernel : public RoundKernel {
public:
ShareSecretsKernel() = default;
~ShareSecretsKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
private:
Executor *executor_;
size_t iteration_time_window_;
armour::CipherShares *cipher_share_;
};
} // namespace kernel
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_SHARE_SECRETS_KERNEL_H

View File

@ -21,6 +21,9 @@
#include <vector>
#include "ps/server/model_store.h"
#include "ps/server/iteration.h"
#ifdef ENABLE_ARMOUR
#include "armour/cipher/cipher_init.h"
#endif
namespace mindspore {
namespace ps {
@ -192,6 +195,21 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
auto fbs_server_mode = fbb->CreateString(PSContext::instance()->server_mode());
auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name());
#ifdef ENABLE_ARMOUR
auto *param = armour::CipherInit::GetInstance().GetPublicParams();
auto prime = fbb->CreateVector(param->prime, PRIME_MAX_LEN);
auto p = fbb->CreateVector(param->p, SECRET_MAX_LEN);
int32_t t = param->t;
int32_t g = param->g;
float dp_eps = param->dp_eps;
float dp_delta = param->dp_delta;
float dp_norm_clip = param->dp_norm_clip;
auto encrypt_type = fbb->CreateString(param->encrypt_type);
auto cipher_public_params =
schema::CreateCipherPublicParams(*fbb.get(), t, p, g, prime, dp_eps, dp_delta, dp_norm_clip, encrypt_type);
#endif
schema::FLPlanBuilder fl_plan_builder(*(fbb.get()));
fl_plan_builder.add_fl_name(fbs_fl_name);
fl_plan_builder.add_server_mode(fbs_server_mode);
@ -199,6 +217,9 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num());
fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size());
fl_plan_builder.add_lr(PSContext::instance()->client_learning_rate());
#ifdef ENABLE_ARMOUR
fl_plan_builder.add_cipher(cipher_public_params);
#endif
auto fbs_fl_plan = fl_plan_builder.Finish();
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;

View File

@ -18,6 +18,9 @@
#include <memory>
#include <string>
#include <csignal>
#ifdef ENABLE_ARMOUR
#include "armour/secure_protocol/secret_sharing.h"
#endif
#include "ps/server/round.h"
#include "ps/server/model_store.h"
#include "ps/server/iteration.h"
@ -39,7 +42,7 @@ void SignalHandler(int signal) {
}
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
const FuncGraphPtr &func_graph, size_t executor_threshold) {
const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold) {
MS_EXCEPTION_IF_NULL(func_graph);
func_graph_ = func_graph;
@ -48,6 +51,7 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
return;
}
rounds_config_ = rounds_config;
cipher_config_ = cipher_config;
use_tcp_ = use_tcp;
use_http_ = use_http;
@ -80,6 +84,7 @@ void Server::Run() {
RegisterCommCallbacks();
StartCommunicator();
InitExecutor();
InitCipher();
RegisterRoundKernel();
MS_LOG(INFO) << "Server started successfully.";
safemode_ = false;
@ -177,6 +182,42 @@ void Server::InitIteration() {
iteration_->AddRound(round);
}
cipher_initial_client_cnt_ = rounds_config_[0].threshold_count;
cipher_exchange_secrets_cnt_ = cipher_initial_client_cnt_ * 1.0;
cipher_share_secrets_cnt_ = cipher_initial_client_cnt_ * cipher_config_.share_secrets_ratio;
cipher_get_clientlist_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_up_cnt_ = rounds_config_[1].threshold_count;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshhold;
MS_LOG(INFO) << "Initializing cipher:";
MS_LOG(INFO) << " cipher_initial_client_cnt_: " << cipher_initial_client_cnt_
<< " cipher_exchange_secrets_cnt_: " << cipher_exchange_secrets_cnt_
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
MS_LOG(INFO) << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_;
#ifdef ENABLE_ARMOUR
std::shared_ptr<Round> exchange_keys_round =
std::make_shared<Round>("exchangeKeys", false, 3000, true, cipher_exchange_secrets_cnt_);
iteration_->AddRound(exchange_keys_round);
std::shared_ptr<Round> get_keys_round =
std::make_shared<Round>("getKeys", false, 3000, true, cipher_exchange_secrets_cnt_);
iteration_->AddRound(get_keys_round);
std::shared_ptr<Round> share_secrets_round =
std::make_shared<Round>("shareSecrets", false, 3000, true, cipher_share_secrets_cnt_);
iteration_->AddRound(share_secrets_round);
std::shared_ptr<Round> get_secrets_round =
std::make_shared<Round>("getSecrets", false, 3000, true, cipher_share_secrets_cnt_);
iteration_->AddRound(get_secrets_round);
std::shared_ptr<Round> get_clientlist_round =
std::make_shared<Round>("getClientList", false, 3000, true, cipher_get_clientlist_cnt_);
iteration_->AddRound(get_clientlist_round);
std::shared_ptr<Round> reconstruct_secrets_round =
std::make_shared<Round>("reconstructSecrets", false, 3000, true, cipher_reconstruct_secrets_up_cnt_);
iteration_->AddRound(reconstruct_secrets_round);
#endif
// 2.Initialize all the rounds.
TimeOutCb time_out_cb =
std::bind(&Iteration::MoveToNextIteration, iteration_, std::placeholders::_1, std::placeholders::_2);
@ -186,6 +227,37 @@ void Server::InitIteration() {
return;
}
void Server::InitCipher() {
#ifdef ENABLE_ARMOUR
cipher_init_ = &armour::CipherInit::GetInstance();
int cipher_t = cipher_reconstruct_secrets_down_cnt_;
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
int cipher_g = 1;
unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
mpz_t prim;
mpz_init(prim);
mindspore::armour::GetRandomPrime(prim);
mindspore::armour::PrintBigInteger(prim, 16);
size_t len_cipher_prime;
mpz_export((unsigned char *)cipher_prime, &len_cipher_prime, sizeof(unsigned char), 1, 0, 0, prim);
mindspore::armour::CipherPublicPara param;
param.g = cipher_g;
param.t = cipher_t;
memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
// param.dp_delta = dp_delta;
// param.dp_eps = dp_eps;
// param.dp_norm_clip = dp_norm_clip;
param.encrypt_type = kNotEncryptType; // PSContext::instance()->encrypt_type;
cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_,
cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
cipher_reconstruct_secrets_up_cnt_);
#endif
}
void Server::RegisterCommCallbacks() {
// The message callbacks of round kernels are already set in method InitIteration, so here we don't need to register
// rounds' callbacks.

View File

@ -26,6 +26,9 @@
#include "ps/server/common.h"
#include "ps/server/executor.h"
#include "ps/server/iteration.h"
#ifdef ENABLE_ARMOUR
#include "armour/cipher/cipher_init.h"
#endif
namespace mindspore {
namespace ps {
@ -39,7 +42,7 @@ class Server {
}
void Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
const FuncGraphPtr &func_graph, size_t executor_threshold);
const CipherConfig &cipher_config, const FuncGraphPtr &func_graph, size_t executor_threshold);
// According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
// blocked until the server is finalized.
@ -92,6 +95,9 @@ class Server {
// Initialize executor according to the server mode.
void InitExecutor();
// Initialize cipher according to the public param.
void InitCipher();
// Create round kernels and bind these kernels with corresponding Round.
void RegisterRoundKernel();
@ -120,6 +126,7 @@ class Server {
// The configure of all rounds.
std::vector<RoundConfig> rounds_config_;
CipherConfig cipher_config_;
// The graph passed by the frontend without backend optimizing.
FuncGraphPtr func_graph_;
@ -147,10 +154,25 @@ class Server {
std::atomic_bool safemode_;
// Variables set by ps context.
#ifdef ENABLE_ARMOUR
armour::CipherInit *cipher_init_;
#endif
std::string scheduler_ip_;
uint16_t scheduler_port_;
uint32_t server_num_;
uint32_t worker_num_;
uint16_t fl_server_port_;
size_t start_fl_job_cnt_;
size_t update_model_cnt_;
size_t cipher_initial_client_cnt_;
size_t cipher_exchange_secrets_cnt_;
size_t cipher_share_secrets_cnt_;
size_t cipher_get_clientlist_cnt_;
size_t cipher_reconstruct_secrets_up_cnt_;
size_t cipher_reconstruct_secrets_down_cnt_;
float percent_for_update_model_;
float percent_for_get_model_;
};
} // namespace server
} // namespace ps

View File

@ -836,14 +836,15 @@ def set_fl_context(**kwargs):
fl_server_port (int): The http port of the federated learning server.
Normally for each server this should be set to the same value. Default: 6668.
enable_fl_client (bool): Whether this process is federated learning client. Default: False.
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 0.
start_fl_job_threshold (int): The threshold count of startFLJob. Default: 1.
start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
update_model_ratio (float): The ratio for computing the threshold count of updateModel
which will be multiplied by start_fl_job_threshold.
Must be between 0 and 1.0.Default: 1.0.
share_secrets_ratio (float): The ratio for computing the threshold count of share secrets. Default: 1.0.
update_model_ratio (float): The ratio for computing the threshold count of updateModel. Default: 1.0.
get_model_ratio (float): The ratio for computing the threshold count of get model. Default: 1.0.
reconstruct_secrets_threshold (int): The threshold count of reconstruct threshold. Default: 0.
update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
fl_name (str): The federated learning job name. Default: ''.
fl_iteration_num (int): Iteration number of federeated learning,
fl_name (string): The federated learning job name. Default: ''.
fl_iteration_num (int): Iteration number of federated learning,
which is the number of interactions between client and server. Default: 20.
client_epoch_num (int): Client training epoch number. Default: 25.
client_batch_size (int): Client training data batch size. Default: 32.

View File

@ -45,6 +45,7 @@ static const std::vector<std::string> sub_module_names = {
"PROFILER", // SM_PROFILER
"PS", // SM_PS
"LITE", // SM_LITE
"ARMOUR", // SM_ARMOUR
"HCCL_ADPT", // SM_HCCL_ADPT
"MINDQUANTUM", // SM_MINDQUANTUM
"RUNTIME_FRAMEWORK", // SM_RUNTIME_FRAMEWORK

View File

@ -132,6 +132,7 @@ enum SubModuleId : int {
SM_PROFILER, // profiler
SM_PS, // Parameter Server
SM_LITE, // LITE
SM_ARMOUR, // ARMOUR
SM_HCCL_ADPT, // Hccl Adapter
SM_MINDQUANTUM, // MindQuantum
SM_RUNTIME_FRAMEWORK, // Runtime framework

View File

@ -57,6 +57,9 @@ _set_ps_context_func_map = {
"start_fl_job_time_window": ps_context().set_start_fl_job_time_window,
"update_model_ratio": ps_context().set_update_model_ratio,
"update_model_time_window": ps_context().set_update_model_time_window,
"share_secrets_ratio": ps_context().set_share_secrets_ratio,
"get_model_ratio": ps_context().set_get_model_ratio,
"reconstruct_secrets_threshhold": ps_context().set_reconstruct_secrets_threshhold,
"fl_name": ps_context().set_fl_name,
"fl_iteration_num": ps_context().set_fl_iteration_num,
"client_epoch_num": ps_context().set_client_epoch_num,

View File

@ -24,7 +24,7 @@ table CipherPublicParams {
dp_eps:float;
dp_delta:float;
dp_norm_clip:float;
encrypt_type:int;
encrypt_type:string;
}
table ClientPublicKeys {

View File

@ -28,6 +28,9 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--get_model_ratio", type=float, default=1.0)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
@ -48,6 +51,9 @@ if __name__ == "__main__":
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
share_secrets_ratio = args.share_secrets_ratio
get_model_ratio = args.get_model_ratio
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
@ -78,6 +84,9 @@ if __name__ == "__main__":
cmd_server += " --start_fl_job_time_window=" + str(start_fl_job_time_window)
cmd_server += " --update_model_ratio=" + str(update_model_ratio)
cmd_server += " --update_model_time_window=" + str(update_model_time_window)
cmd_server += " --share_secrets_ratio=" + str(share_secrets_ratio)
cmd_server += " --get_model_ratio=" + str(get_model_ratio)
cmd_server += " --reconstruct_secrets_threshhold=" + str(reconstruct_secrets_threshhold)
cmd_server += " --fl_name=" + fl_name
cmd_server += " --fl_iteration_num=" + str(fl_iteration_num)
cmd_server += " --client_epoch_num=" + str(client_epoch_num)

View File

@ -36,6 +36,9 @@ parser.add_argument("--start_fl_job_threshold", type=int, default=1)
parser.add_argument("--start_fl_job_time_window", type=int, default=3000)
parser.add_argument("--update_model_ratio", type=float, default=1.0)
parser.add_argument("--update_model_time_window", type=int, default=3000)
parser.add_argument("--share_secrets_ratio", type=float, default=1.0)
parser.add_argument("--get_model_ratio", type=float, default=1.0)
parser.add_argument("--reconstruct_secrets_threshhold", type=int, default=0)
parser.add_argument("--fl_name", type=str, default="Lenet")
parser.add_argument("--fl_iteration_num", type=int, default=25)
parser.add_argument("--client_epoch_num", type=int, default=20)
@ -56,6 +59,9 @@ start_fl_job_threshold = args.start_fl_job_threshold
start_fl_job_time_window = args.start_fl_job_time_window
update_model_ratio = args.update_model_ratio
update_model_time_window = args.update_model_time_window
share_secrets_ratio = args.share_secrets_ratio
get_model_ratio = args.get_model_ratio
reconstruct_secrets_threshhold = args.reconstruct_secrets_threshhold
fl_name = args.fl_name
fl_iteration_num = args.fl_iteration_num
client_epoch_num = args.client_epoch_num
@ -76,6 +82,9 @@ ctx = {
"start_fl_job_time_window": start_fl_job_time_window,
"update_model_ratio": update_model_ratio,
"update_model_time_window": update_model_time_window,
"share_secrets_ratio": share_secrets_ratio,
"get_model_ratio": get_model_ratio,
"reconstruct_secrets_threshhold": reconstruct_secrets_threshhold,
"fl_name": fl_name,
"fl_iteration_num": fl_iteration_num,
"client_epoch_num": client_epoch_num,

View File

@ -159,6 +159,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/ps/*.cc"
"../../../mindspore/ccsrc/profiler/device/common/*.cc"
"../../../mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/adam_fp32.c"
"../../../mindspore/ccsrc/armour/*.cc"
)
list(REMOVE_ITEM MINDSPORE_SRC_LIST