!27108 code sync in federated learning

Merge pull request !27108 from tan-wei-cheng-3260/develop-twc-sync2
This commit is contained in:
i-robot 2021-12-02 11:57:55 +00:00 committed by Gitee
commit 82542f1d4a
127 changed files with 4413 additions and 1219 deletions

View File

@ -30,7 +30,7 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// The duration between two PullWeights requests when return code is ResponseCode_SucNotReady. // The duration between two PullWeight requests when return code is ResponseCode_SucNotReady.
constexpr int kRetryDurationOfPullWeights = 200; constexpr int kRetryDurationOfPullWeights = 200;
template <typename T> template <typename T>
class FusedPullWeightKernel : public CPUKernel { class FusedPullWeightKernel : public CPUKernel {
@ -52,11 +52,11 @@ class FusedPullWeightKernel : public CPUKernel {
total_iteration_++; total_iteration_++;
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration(); uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
if (step_num_per_iteration == 0) { if (step_num_per_iteration == 0) {
MS_LOG(EXCEPTION) << "Step numbers of per iteration should not be equal to 0"; MS_LOG(EXCEPTION) << "step number per iteration should not be 0";
} }
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
MS_LOG(INFO) << "Try to pull weights. Local step number: " << total_iteration_ MS_LOG(INFO) << "Try to pull weights. Local step number: " << total_iteration_
<< ", step number needs to run per iteration: " << step_num_per_iteration; << ", step number needs to run per iteration: " << step_num_per_iteration;
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
if (step_num_per_iteration != fl::kOneStepPerIteration && if (step_num_per_iteration != fl::kOneStepPerIteration &&
total_iteration_ % step_num_per_iteration != fl::kTrainBeginStepNum) { total_iteration_ % step_num_per_iteration != fl::kTrainBeginStepNum) {
return true; return true;
@ -86,6 +86,7 @@ class FusedPullWeightKernel : public CPUKernel {
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg); MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
pull_weight_rsp = flatbuffers::GetRoot<schema::ResponsePullWeight>(pull_weight_rsp_msg->data()); pull_weight_rsp = flatbuffers::GetRoot<schema::ResponsePullWeight>(pull_weight_rsp_msg->data());
MS_EXCEPTION_IF_NULL(pull_weight_rsp);
retcode = pull_weight_rsp->retcode(); retcode = pull_weight_rsp->retcode();
if (retcode == schema::ResponseCode_SucNotReady) { if (retcode == schema::ResponseCode_SucNotReady) {
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPullWeights)); std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPullWeights));
@ -95,11 +96,11 @@ class FusedPullWeightKernel : public CPUKernel {
// Recreate fbb to avoid memory leak of FlatBuffers. // Recreate fbb to avoid memory leak of FlatBuffers.
fbb = std::make_shared<fl::FBBuilder>(); fbb = std::make_shared<fl::FBBuilder>();
if (!BuildPullWeightReq(fbb)) { if (!BuildPullWeightReq(fbb)) {
MS_LOG(EXCEPTION) << "Building request for FusedDownloadWeightsByKeys failed."; MS_LOG(EXCEPTION) << "Building request for FusedPullWeight failed.";
} }
continue; continue;
} else if (retcode != schema::ResponseCode_SUCCEED) { } else if (retcode != schema::ResponseCode_SUCCEED) {
MS_LOG(EXCEPTION) << "FusedPullWeight failed. Server return code: " << pull_weight_rsp->retcode() MS_LOG(WARNING) << "FusedPullWeight failed. Server return code: " << pull_weight_rsp->retcode()
<< ", reason: " << pull_weight_rsp->reason()->str(); << ", reason: " << pull_weight_rsp->reason()->str();
} else { } else {
MS_LOG(DEBUG) << "FusedPullWeight succeed."; MS_LOG(DEBUG) << "FusedPullWeight succeed.";

View File

@ -28,7 +28,7 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// The duration between two PushWeights requests when return code is ResponseCode_SucNotReady. // The duration between two PushWeight requests when return code is ResponseCode_SucNotReady.
constexpr int kRetryDurationOfPushWeights = 200; constexpr int kRetryDurationOfPushWeights = 200;
template <typename T> template <typename T>
class FusedPushWeightKernel : public CPUKernel { class FusedPushWeightKernel : public CPUKernel {
@ -50,11 +50,11 @@ class FusedPushWeightKernel : public CPUKernel {
total_iteration_++; total_iteration_++;
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration(); uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
if (step_num_per_iteration == 0) { if (step_num_per_iteration == 0) {
MS_LOG(EXCEPTION) << "Step numbers of per iteration should not be equal to 0"; MS_LOG(EXCEPTION) << "step number per iterationb should not be 0";
} }
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
MS_LOG(INFO) << "Try to push weights. Local step number: " << total_iteration_ MS_LOG(INFO) << "Try to push weights. Local step number: " << total_iteration_
<< ", step number needs to run per iteration: " << step_num_per_iteration; << ", step number needs to run per iteration: " << step_num_per_iteration;
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
if (step_num_per_iteration != fl::kOneStepPerIteration && if (step_num_per_iteration != fl::kOneStepPerIteration &&
total_iteration_ % step_num_per_iteration != fl::kTrainEndStepNum) { total_iteration_ % step_num_per_iteration != fl::kTrainEndStepNum) {
return true; return true;
@ -87,6 +87,7 @@ class FusedPushWeightKernel : public CPUKernel {
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg); MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
push_weight_rsp = flatbuffers::GetRoot<schema::ResponsePushWeight>(push_weight_rsp_msg->data()); push_weight_rsp = flatbuffers::GetRoot<schema::ResponsePushWeight>(push_weight_rsp_msg->data());
MS_EXCEPTION_IF_NULL(push_weight_rsp);
retcode = push_weight_rsp->retcode(); retcode = push_weight_rsp->retcode();
if (retcode == schema::ResponseCode_SucNotReady) { if (retcode == schema::ResponseCode_SucNotReady) {
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushWeights)); std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushWeights));
@ -98,7 +99,7 @@ class FusedPushWeightKernel : public CPUKernel {
} }
continue; continue;
} else if (retcode != schema::ResponseCode_SUCCEED) { } else if (retcode != schema::ResponseCode_SUCCEED) {
MS_LOG(EXCEPTION) << "FusedPushWeight failed. Server return code: " << push_weight_rsp->retcode() MS_LOG(WARNING) << "FusedPushWeight failed. Server return code: " << push_weight_rsp->retcode()
<< ", reason: " << push_weight_rsp->reason()->str(); << ", reason: " << push_weight_rsp->reason()->str();
} else { } else {
MS_LOG(DEBUG) << "FusedPushWeight succeed."; MS_LOG(DEBUG) << "FusedPushWeight succeed.";

View File

@ -87,7 +87,7 @@ class PushMetricsKernel : public CPUKernel {
case schema::ResponseCode_OutOfTime: case schema::ResponseCode_OutOfTime:
break; break;
default: default:
MS_LOG(EXCEPTION) << "Launching push metrics for worker failed."; MS_LOG(WARNING) << "Launching push metrics for worker failed.";
} }
MS_LOG(INFO) << "Push metrics for loss and accuracy success."; MS_LOG(INFO) << "Push metrics for loss and accuracy success.";

View File

@ -5,7 +5,6 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/aggregation_kernel_factory.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/fed_avg_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/fed_avg_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/sgd_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/optimizer_kernel_factory.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel_factory.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel.cc")
@ -20,6 +19,8 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_secrets_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_secrets_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/reconstruct_secrets_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/reconstruct_secrets_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/share_secrets_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/share_secrets_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_list_sign_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_list_sign_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_metrics_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_metrics_kernel.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/params_info.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/params_info.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/consistent_hash_ring.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/consistent_hash_ring.cc")
@ -35,6 +36,8 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "server/model_store.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/model_store.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/round.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/round.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/server.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/server.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/cert_verify.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/server_recovery.cc")
list(REMOVE_ITEM _FL_SRC_FILES "server/iteration_metrics.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/iteration_metrics.cc")
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc") list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
@ -49,10 +52,6 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc")
endif() endif()
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-tautological-pointer-compare")
endif()
list(LENGTH _FL_SRC_FILES fl_file_num) list(LENGTH _FL_SRC_FILES fl_file_num)
if(NOT fl_file_num EQUAL 0) if(NOT fl_file_num EQUAL 0)
set_property(SOURCE ${_FL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_FL) set_property(SOURCE ${_FL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_FL)

View File

@ -24,8 +24,8 @@ namespace mindspore {
namespace armour { namespace armour {
bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_exchange_keys_cnt, bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt, size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt, size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt,
size_t cipher_reconstruct_secrets_up_cnt) { size_t cipher_get_list_sign_cnt, size_t cipher_reconstruct_secrets_up_cnt) {
MS_LOG(INFO) << "CipherInit::Init START"; MS_LOG(INFO) << "CipherInit::Init START";
if (publicparam_.p == nullptr || param.p == nullptr || param.prime == nullptr || publicparam_.prime == nullptr) { if (publicparam_.p == nullptr || param.p == nullptr || param.prime == nullptr || publicparam_.prime == nullptr) {
MS_LOG(ERROR) << "CipherInit::input data invalid."; MS_LOG(ERROR) << "CipherInit::input data invalid.";
@ -47,6 +47,8 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
get_secrets_threshold = cipher_get_secrets_cnt; get_secrets_threshold = cipher_get_secrets_cnt;
client_list_threshold = cipher_get_clientlist_cnt; client_list_threshold = cipher_get_clientlist_cnt;
reconstruct_secrets_threshold = cipher_reconstruct_secrets_up_cnt; reconstruct_secrets_threshold = cipher_reconstruct_secrets_up_cnt;
push_list_sign_threshold = cipher_push_list_sign_cnt;
get_list_sign_threshold = cipher_get_list_sign_cnt;
time_out_mutex_ = time_out_mutex; time_out_mutex_ = time_out_mutex;
publicparam_.dp_eps = param.dp_eps; publicparam_.dp_eps = param.dp_eps;
@ -74,6 +76,8 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
MS_LOG(INFO) << " CipherInit get_secrets_threshold : " << get_secrets_threshold; MS_LOG(INFO) << " CipherInit get_secrets_threshold : " << get_secrets_threshold;
MS_LOG(INFO) << " CipherInit client_list_threshold : " << client_list_threshold; MS_LOG(INFO) << " CipherInit client_list_threshold : " << client_list_threshold;
MS_LOG(INFO) << " CipherInit reconstruct_secrets_threshold : " << reconstruct_secrets_threshold; MS_LOG(INFO) << " CipherInit reconstruct_secrets_threshold : " << reconstruct_secrets_threshold;
MS_LOG(INFO) << " CipherInit push_list_sign_threshold : " << push_list_sign_threshold;
MS_LOG(INFO) << " CipherInit get_list_sign_threshold : " << get_list_sign_threshold;
MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_; MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_;
if (!Check_Parames()) { if (!Check_Parames()) {
MS_LOG(ERROR) << "Cipher parameters are illegal."; MS_LOG(ERROR) << "Cipher parameters are illegal.";
@ -81,7 +85,6 @@ bool CipherInit::Init(const CipherPublicPara &param, size_t time_out_mutex, size
} }
MS_LOG(INFO) << " CipherInit::Init Success"; MS_LOG(INFO) << " CipherInit::Init Success";
} }
if (param.encrypt_type == mindspore::ps::kStablePWEncryptType) { if (param.encrypt_type == mindspore::ps::kStablePWEncryptType) {
cipher_meta_storage_.RegisterStablePWClass(); cipher_meta_storage_.RegisterStablePWClass();
MS_LOG(INFO) << "Register metadata for StablePWEncrypt is finished."; MS_LOG(INFO) << "Register metadata for StablePWEncrypt is finished.";
@ -96,9 +99,10 @@ bool CipherInit::Check_Parames() {
} }
if (share_secrets_threshold < reconstruct_secrets_threshold) { if (share_secrets_threshold < reconstruct_secrets_threshold) {
MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger " MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger than "
"than share_secrets_threshold, but got they are:" "share_secrets_threshold."
<< reconstruct_secrets_threshold << ", " << share_secrets_threshold; << "reconstruct_secrets_threshold: " << reconstruct_secrets_threshold
<< ", share_secrets_threshold: " << share_secrets_threshold;
return false; return false;
} }

View File

@ -40,27 +40,34 @@ class CipherInit {
// Initialize the parameters of the secure aggregation. // Initialize the parameters of the secure aggregation.
bool Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_exchange_keys_cnt, bool Init(const CipherPublicPara &param, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt, size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt, size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt, size_t cipher_get_list_sign_cnt,
size_t cipher_reconstruct_secrets_up_cnt); size_t cipher_reconstruct_secrets_up_cnt);
// Get public params. which is given to start fl job thread. // Get public params. which is given to start fl job thread.
CipherPublicPara *GetPublicParams() { return &publicparam_; } CipherPublicPara *GetPublicParams() { return &publicparam_; }
size_t share_secrets_threshold; // the minimum number of clients to share secret fragments. size_t share_secrets_threshold; // the minimum number of clients to share
size_t reconstruct_secrets_threshold; // the minimum number of clients to reconstruct secret mask. // secret fragments.
size_t exchange_key_threshold; // the minimum number of clients to send public keys. size_t reconstruct_secrets_threshold; // the minimum number of clients to
size_t push_list_sign_threshold; // the minimum number of clients to push client list signature. // reconstruct secret mask.
size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask. size_t exchange_key_threshold; // the minimum number of clients to send public
// keys.
size_t push_list_sign_threshold; // the minimum number of clients to push
// client list signature.
size_t secrets_minnums_; // the minimum number of secret fragment s to
// reconstruct secret mask.
size_t featuremap_; // the size of data to deal. size_t featuremap_; // the size of data to deal.
CipherPublicPara publicparam_; // the param containing encrypted public parameters. CipherPublicPara publicparam_; // the param containing encrypted public parameters.
CipherMetaStorage cipher_meta_storage_; CipherMetaStorage cipher_meta_storage_;
private: private:
size_t client_list_threshold; // the minimum number of clients to get update model client list. size_t client_list_threshold; // the minimum number of clients to get update
// model client list.
size_t get_key_threshold; // the minimum number of clients to get public keys. size_t get_key_threshold; // the minimum number of clients to get public keys.
size_t get_list_sign_threshold; // the minimum number of clients to get client list signature. size_t get_list_sign_threshold; // the minimum number of clients to get client
size_t get_secrets_threshold; // the minimum number of clients to get secret fragments. // list signature.
size_t get_secrets_threshold; // the minimum number of clients to get secret
// fragments.
size_t time_out_mutex_; // timeout mutex. size_t time_out_mutex_; // timeout mutex.
// Check whether the parameters are valid. // Check whether the parameters are valid.

View File

@ -0,0 +1,633 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fl/server/cert_verify.h"
#include <sys/time.h>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <vector>
#include <iomanip>
#include <sstream>
namespace mindspore {
namespace ps {
namespace server {
#ifndef _WIN32
static int64_t replayAttackTimeDiff;
X509 *CertVerify::readCertFromFile(const std::string &certPath) {
BIO *bio = BIO_new_file(certPath.c_str(), "r");
X509 *certObj = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
BIO_free_all(bio);
return certObj;
}
X509 *CertVerify::readCertFromPerm(std::string cert) {
BIO *bio = BIO_new_mem_buf(reinterpret_cast<void *>(cert.data()), -1);
X509 *certObj = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
BIO_free_all(bio);
return certObj;
}
X509_CRL *CertVerify::readCrlFromFile(const std::string &crlPath) {
BIO *bio = BIO_new_file(crlPath.c_str(), "r");
X509_CRL *crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
BIO_free_all(bio);
return crl;
}
bool checkFileExists(const std::string &file) {
std::ifstream f(file.c_str());
if (!f.good()) {
return false;
} else {
f.close();
return true;
}
}
bool CertVerify::verifyCertTime(const X509 *cert) const {
ASN1_TIME *start = X509_getm_notBefore(cert);
ASN1_TIME *end = X509_getm_notAfter(cert);
int day = 0;
int sec = 0;
int ret = ASN1_TIME_diff(&day, &sec, start, NULL);
if (ret != 1) {
return false;
}
if (day < 0 || sec < 0) {
MS_LOG(ERROR) << "cert start time is later than now time.";
return false;
}
day = 0;
sec = 0;
ret = ASN1_TIME_diff(&day, &sec, NULL, end);
if (ret != 1) {
return false;
}
if (day < 0 || sec < 0) {
MS_LOG(ERROR) << "cert end time is sooner than now time.";
return false;
}
return true;
}
bool CertVerify::verifyPublicKey(const X509 *keyAttestationCertObj, const X509 *equipCertObj,
const X509 *equipCACertObj, const X509 *rootFirstCA, const X509 *rootSecondCA) const {
bool result = true;
EVP_PKEY *equipPubKey = X509_get_pubkey(const_cast<X509 *>(equipCertObj));
EVP_PKEY *equipCAPubKey = X509_get_pubkey(const_cast<X509 *>(equipCACertObj));
EVP_PKEY *rootFirstPubKey = X509_get_pubkey(const_cast<X509 *>(rootFirstCA));
EVP_PKEY *rootSecondPubKey = X509_get_pubkey(const_cast<X509 *>(rootSecondCA));
do {
int ret = 0;
ret = X509_verify(const_cast<X509 *>(keyAttestationCertObj), equipPubKey);
if (ret != 1) {
MS_LOG(ERROR) << "keyAttestationCert verify is failed";
result = false;
break;
}
ret = X509_verify(const_cast<X509 *>(equipCertObj), equipCAPubKey);
if (ret != 1) {
MS_LOG(ERROR) << "equip cert verify is failed";
result = false;
break;
}
int ret_first = X509_verify(const_cast<X509 *>(equipCACertObj), rootFirstPubKey);
int ret_second = X509_verify(const_cast<X509 *>(equipCACertObj), rootSecondPubKey);
if (ret_first != 1 && ret_second != 1) {
MS_LOG(ERROR) << "equip ca cert verify is failed";
result = false;
break;
}
} while (0);
EVP_PKEY_free(equipPubKey);
EVP_PKEY_free(equipCAPubKey);
EVP_PKEY_free(rootFirstPubKey);
EVP_PKEY_free(rootSecondPubKey);
MS_LOG(INFO) << "verify Public Key success.";
return result;
}
bool CertVerify::verifyCAChain(const std::string &keyAttestation, const std::string &equipCert,
const std::string &equipCACert, const std::string &rootFirstCAPath,
const std::string &rootSecondCAPath) {
X509 *rootFirstCA = CertVerify::readCertFromFile(rootFirstCAPath);
X509 *rootSecondCA = CertVerify::readCertFromFile(rootSecondCAPath);
X509 *keyAttestationCertObj = readCertFromPerm(keyAttestation);
X509 *equipCertObj = readCertFromPerm(equipCert);
X509 *equipCACertObj = readCertFromPerm(equipCACert);
bool result = true;
do {
if (rootFirstCA == nullptr || rootSecondCA == nullptr) {
MS_LOG(ERROR) << "rootFirstCA or rootSecondCA is nullptr";
result = false;
break;
}
if (keyAttestationCertObj == nullptr || equipCertObj == nullptr || equipCACertObj == nullptr) {
result = false;
break;
}
if (!verifyCertTime(keyAttestationCertObj) || !verifyCertTime(equipCertObj) || !verifyCertTime(equipCACertObj)) {
result = false;
break;
}
if (!verifyCertCommonName(equipCACertObj, equipCertObj)) {
MS_LOG(ERROR) << "equip ca cert subject cn is not equal with equip cert issuer cn.";
result = false;
break;
}
if (!verifyCertCommonName(rootFirstCA, equipCACertObj) && !verifyCertCommonName(rootSecondCA, equipCACertObj)) {
MS_LOG(ERROR) << "root CA cert subject cn is not equal with equip CA cert issuer cn.";
result = false;
break;
}
if (!verifyExtendedAttributes(equipCACertObj)) {
MS_LOG(ERROR) << "verify equipCACert Extended Attributes failed.";
result = false;
break;
}
if (!verifyCertKeyID(rootFirstCA, equipCACertObj) && !verifyCertKeyID(rootSecondCA, equipCACertObj)) {
MS_LOG(ERROR) << "root CA cert subject keyid is not equal with equip CA cert issuer keyid.";
result = false;
break;
}
if (!verifyCertKeyID(equipCACertObj, equipCertObj)) {
MS_LOG(ERROR) << "equip CA cert subject keyid is not equal with equip cert issuer keyid.";
result = false;
break;
}
if (!verifyPublicKey(keyAttestationCertObj, equipCertObj, equipCACertObj, rootFirstCA, rootSecondCA)) {
MS_LOG(ERROR) << "verify Public Key failed";
result = false;
break;
}
} while (0);
X509_free(rootFirstCA);
X509_free(rootSecondCA);
X509_free(keyAttestationCertObj);
X509_free(equipCertObj);
X509_free(equipCACertObj);
MS_LOG(INFO) << "verifyCAChain success.";
return result;
}
bool CertVerify::verifyCertKeyID(const X509 *caCert, const X509 *subCert) const {
bool result = true;
ASN1_OCTET_STRING *skid = nullptr;
AUTHORITY_KEYID *akeyid = nullptr;
do {
int crit = 0;
skid = reinterpret_cast<ASN1_OCTET_STRING *>(X509_get_ext_d2i(caCert, NID_subject_key_identifier, &crit, NULL));
if (skid == nullptr) {
result = false;
break;
}
char subject_keyid[512] = {0};
for (int i = 0; i < skid->length; i++) {
char keyid[8] = {0};
int base = 512;
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)skid->data[i]);
int ret = strcat_s(subject_keyid, base, keyid);
if (ret == -1) {
result = false;
break;
}
}
akeyid = reinterpret_cast<AUTHORITY_KEYID *>(X509_get_ext_d2i(subCert, NID_authority_key_identifier, &crit, NULL));
if (akeyid == nullptr) {
result = false;
break;
}
char issuer_keyid[512] = {0};
if (akeyid->keyid == nullptr) {
MS_LOG(ERROR) << "keyid is nullprt.";
result = false;
break;
}
for (int i = 0; i < akeyid->keyid->length; i++) {
char keyid[8] = {0};
int base = 512;
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)(akeyid->keyid->data[i]));
int ret = strcat_s(issuer_keyid, base, keyid);
if (ret == -1) {
result = false;
break;
}
}
std::string subject_keyid_str = subject_keyid;
std::string issuer_keyid_str = issuer_keyid;
if (subject_keyid_str != issuer_keyid_str) {
result = false;
break;
}
} while (0);
ASN1_OCTET_STRING_free(skid);
AUTHORITY_KEYID_free(akeyid);
return result;
}
bool CertVerify::verifyExtendedAttributes(const X509 *cert) const {
bool result = true;
BASIC_CONSTRAINTS *bcons = nullptr;
ASN1_BIT_STRING *lASN1UsageStr = nullptr;
do {
int cirt = 0;
bcons = reinterpret_cast<BASIC_CONSTRAINTS *>(X509_get_ext_d2i(cert, NID_basic_constraints, &cirt, NULL));
if (bcons == nullptr) {
result = false;
break;
}
if (!bcons->ca) {
MS_LOG(ERROR) << "Subject Type is End Entity.";
result = false;
break;
}
MS_LOG(INFO) << "Subject Type is CA.";
lASN1UsageStr = reinterpret_cast<ASN1_BIT_STRING *>(X509_get_ext_d2i(cert, NID_key_usage, NULL, NULL));
if (lASN1UsageStr == nullptr) {
result = false;
break;
}
int16_t usage = lASN1UsageStr->data[0];
if (lASN1UsageStr->length > 1) {
const unsigned int move = 8;
usage |= lASN1UsageStr->data[1] << move;
}
if (!(usage & KU_KEY_CERT_SIGN)) {
MS_LOG(ERROR) << "Subject is not Certificate Signature.";
result = false;
break;
}
MS_LOG(INFO) << "Subject is Certificate Signature.";
} while (0);
BASIC_CONSTRAINTS_free(bcons);
ASN1_BIT_STRING_free(lASN1UsageStr);
return result;
}
bool CertVerify::verifyCertCommonName(const X509 *caCert, const X509 *subCert) const {
if (caCert == nullptr || subCert == nullptr) {
return false;
}
char caSubjectCN[256] = "";
char subIssuerCN[256] = "";
X509_NAME *caSubjectX509CN = X509_get_subject_name(caCert);
X509_NAME *subIssuerX509CN = X509_get_issuer_name(subCert);
int ret = X509_NAME_get_text_by_NID(caSubjectX509CN, NID_commonName, caSubjectCN, sizeof(caSubjectCN));
if (ret < 0) {
return false;
}
ret = X509_NAME_get_text_by_NID(subIssuerX509CN, NID_commonName, subIssuerCN, sizeof(subIssuerCN));
if (ret < 0) {
return false;
}
std::string caSubjectCNStr = caSubjectCN;
std::string subIssuerCNStr = subIssuerCN;
if (caSubjectCNStr != subIssuerCNStr) {
return false;
}
return true;
}
bool CertVerify::verifyCRL(const std::string &equipCert, const std::string &equipCrlPath) {
if (!checkFileExists(equipCrlPath)) {
return true;
}
bool result = true;
X509_CRL *equipCrl = nullptr;
X509 *equipCertObj = nullptr;
EVP_PKEY *evp_pkey = nullptr;
do {
equipCrl = CertVerify::readCrlFromFile(equipCrlPath);
equipCertObj = readCertFromPerm(equipCert);
if (equipCertObj == nullptr) {
result = false;
break;
}
if (equipCrl == nullptr) {
MS_LOG(INFO) << "equipCrl is nullptr. return true.";
result = true;
break;
}
evp_pkey = X509_get_pubkey(equipCertObj);
int ret = X509_CRL_verify(equipCrl, evp_pkey);
if (ret == 1) {
MS_LOG(ERROR) << "equip cert in equip crl, verify failed";
result = false;
break;
}
} while (0);
EVP_PKEY_free(evp_pkey);
X509_free(equipCertObj);
X509_CRL_free(equipCrl);
MS_LOG(INFO) << "verifyCRL success.";
return result;
}
bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned char *signData, const std::string &flID,
const std::string &timeStamp) {
if (keyAttestation.empty() || signData == nullptr || flID.empty() || timeStamp.empty()) {
MS_LOG(ERROR) << "keyAttestation or signData or flID or timeStamp is empty.";
return false;
}
bool result = true;
X509 *keyAttestationCertObj = nullptr;
EVP_PKEY *pubKey = nullptr;
do {
keyAttestationCertObj = readCertFromPerm(keyAttestation);
std::string srcData = flID + " " + timeStamp;
// SHA256_DIGEST_LENGTH is 32
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
sha256Hash(srcData, srcDataHash, SHA256_DIGEST_LENGTH);
pubKey = X509_get_pubkey(keyAttestationCertObj);
RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey);
if (pRSAPublicKey == nullptr) {
MS_LOG(ERROR) << "get rsa public key failed.";
result = false;
break;
}
int pubKeyLen = RSA_size(pRSAPublicKey);
unsigned char buffer[256];
int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING);
if (ret == -1) {
MS_LOG(ERROR) << "rsa public decrypt failed.";
result = false;
break;
}
int saltLen = -2;
ret = RSA_verify_PKCS1_PSS(pRSAPublicKey, srcDataHash, EVP_sha256(), buffer, saltLen);
if (ret != 1) {
int64_t ulErr = SizeToLong(ERR_get_error());
char szErrMsg[1024] = {0};
MS_LOG(ERROR) << "verify error. error number: " << ulErr;
std::string str_res = ERR_error_string(ulErr, szErrMsg);
MS_LOG(ERROR) << szErrMsg;
if (str_res.empty()) {
result = false;
break;
}
result = false;
break;
}
} while (0);
EVP_PKEY_free(pubKey);
X509_free(keyAttestationCertObj);
CRYPTO_cleanup_all_ex_data();
MS_LOG(INFO) << "verifyRSAKey success.";
return result;
}
void CertVerify::sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash, const int len) const {
if (len <= 0) {
return;
}
SHA256_CTX sha_ctx;
int ret = SHA256_Init(&sha_ctx);
if (ret != 1) {
return;
}
ret = SHA256_Update(&sha_ctx, src, src_len);
if (ret != 1) {
return;
}
ret = SHA256_Final(hash, &sha_ctx);
if (ret != 1) {
return;
}
}
std::string CertVerify::toHexString(const unsigned char *data, const int len) {
if (data == nullptr) {
MS_LOG(ERROR) << "data hash is null.";
return "";
}
if (len <= 0) {
return "";
}
std::stringstream ss;
int base = 2;
for (int i = 0; i < len; i++) {
ss << std::hex << std::setw(base) << std::setfill('0') << static_cast<int>(data[i]);
}
return ss.str();
}
bool CertVerify::verifyEquipCertAndFlID(const std::string &flID, const std::string &equipCert) {
unsigned char hash[SHA256_DIGEST_LENGTH] = {""};
sha256Hash(equipCert, hash, SHA256_DIGEST_LENGTH);
std::string equipCertSha256 = toHexString(hash, SHA256_DIGEST_LENGTH);
if (flID == equipCertSha256) {
MS_LOG(INFO) << "verifyEquipCertAndFlID success.";
return true;
} else {
MS_LOG(ERROR) << "verifyEquipCertAndFlID failed.";
return false;
}
}
bool CertVerify::verifyTimeStamp(const std::string &flID, const std::string &timeStamp) const {
int64_t requestTime = std::stoll(timeStamp.c_str());
const int64_t base = 1000;
struct timeval tv {};
int ret = gettimeofday(&tv, nullptr);
if (ret != 0) {
return false;
}
int64_t now = tv.tv_sec * base + tv.tv_usec / base;
MS_LOG(INFO) << "flID: " << flID.c_str() << ",now time: " << now << ",requestTime: " << requestTime;
int64_t diff = now - requestTime;
if (diff > replayAttackTimeDiff || diff < 0) {
return false;
}
MS_LOG(INFO) << "verifyTimeStamp success.";
return true;
}
void CertVerify::sha256Hash(const std::string &src, uint8_t *hash, const int len) const {
if (len <= 0) {
return;
}
SHA256_CTX sha_ctx;
int ret = SHA256_Init(&sha_ctx);
if (ret != 1) {
return;
}
ret = SHA256_Update(&sha_ctx, src.c_str(), src.size());
if (ret != 1) {
return;
}
ret = SHA256_Final(hash, &sha_ctx);
if (ret != 1) {
return;
}
}
bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData,
int srcDataLen) {
if (keyAttestation.empty() || signData == nullptr || srcData == nullptr || srcDataLen <= 0) {
MS_LOG(ERROR) << "keyAttestation or signData or srcData is invalid.";
return false;
}
bool result = true;
X509 *keyAttestationCertObj = nullptr;
EVP_PKEY *pubKey = nullptr;
do {
keyAttestationCertObj = readCertFromPerm(keyAttestation);
pubKey = X509_get_pubkey(keyAttestationCertObj);
RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey);
if (pRSAPublicKey == nullptr) {
MS_LOG(ERROR) << "get rsa public key failed.";
result = false;
break;
}
int pubKeyLen = RSA_size(pRSAPublicKey);
unsigned char buffer[256];
int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING);
if (ret == -1) {
MS_LOG(ERROR) << "rsa public decrypt failed.";
result = false;
break;
}
int saltLen = -2;
ret = RSA_verify_PKCS1_PSS(pRSAPublicKey, srcData, EVP_sha256(), buffer, saltLen);
if (ret != 1) {
int64_t ulErr = SizeToLong(ERR_get_error());
char szErrMsg[1024] = {0};
MS_LOG(ERROR) << "verify error. error number: " << ulErr;
std::string str_res = ERR_error_string(ulErr, szErrMsg);
MS_LOG(ERROR) << szErrMsg;
if (str_res.empty()) {
result = false;
break;
}
result = false;
break;
}
} while (0);
EVP_PKEY_free(pubKey);
X509_free(keyAttestationCertObj);
CRYPTO_cleanup_all_ex_data();
MS_LOG(INFO) << "verifyRSAKey success.";
return result;
}
bool CertVerify::initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath,
const std::string equipCrlPath, const uint64_t replay_attack_time_diff) {
if (rootFirstCaFilePath.empty() || rootSecondCaFilePath.empty()) {
MS_LOG(ERROR) << "the root or crl path is empty.";
return false;
}
if (!checkFileExists(rootFirstCaFilePath)) {
MS_LOG(ERROR) << "The rootFirstCaFilePath is not exist.";
return false;
}
if (!checkFileExists(rootSecondCaFilePath)) {
MS_LOG(ERROR) << "The rootSecondCaFilePath is not exist.";
return false;
}
replayAttackTimeDiff = UlongToLong(replay_attack_time_diff);
return true;
}
bool CertVerify::verifyCertAndSign(const std::string &flID, const std::string &timeStamp, const unsigned char *signData,
const std::string &keyAttestation, const std::string &equipCert,
const std::string &equipCACert, const std::string &rootFirstCAPath,
const std::string &rootSecondCAPath, const std::string &equipCrlPath) {
if (!verifyEquipCertAndFlID(flID, equipCert)) {
return false;
}
if (!verifyCAChain(keyAttestation, equipCert, equipCACert, rootFirstCAPath, rootSecondCAPath)) {
return false;
}
if (!verifyCRL(equipCert, equipCrlPath)) {
return false;
}
if (!verifyRSAKey(keyAttestation, signData, flID, timeStamp)) {
return false;
}
if (!verifyTimeStamp(flID, timeStamp)) {
return false;
}
return true;
}
#else
bool CertVerify::verifyTimeStamp(const std::string &flID, const std::string &timeStamp) const {
MS_LOG(WARNING) << "verifyTimeStamp in win32 platform.";
return false;
}
void CertVerify::sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash, const int len) const {
MS_LOG(WARNING) << "sha256Hash in win32 platform.";
}
bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData,
int srcDataLen) {
MS_LOG(WARNING) << "verifyRSAKey in win32 platform.";
return false;
}
bool CertVerify::initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath,
const std::string equipCrlPath, const uint64_t replay_attack_time_diff) {
MS_LOG(WARNING) << "initRootCertAndCRL in win32 platform.";
return false;
}
bool CertVerify::verifyCertAndSign(const std::string &flID, const std::string &timeStamp, const unsigned char *signData,
const std::string &keyAttestation, const std::string &equipCert,
const std::string &equipCACert, const std::string &rootFirstCAPath,
const std::string &rootSecondCAPath, const std::string &equipCrlPath) {
MS_LOG(WARNING) << "verifyCertAndSign in win32 platform.";
return false;
}
#endif
} // namespace server
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,105 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_SERVER_CERT_VERIFY_H
#define MINDSPORE_CCSRC_FL_SERVER_CERT_VERIFY_H
#include <assert.h>
#ifndef _WIN32
#include <openssl/evp.h>
#include <openssl/rsa.h>
#include <openssl/x509v3.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/sha.h>
#endif
#include <iostream>
#include <fstream>
#include <string>
#include "utils/log_adapter.h"
#include "fl/server/common.h"
namespace mindspore {
namespace ps {
namespace server {
class CertVerify {
public:
CertVerify() {}
~CertVerify() = default;
bool verifyCertAndSign(const std::string &flID, const std::string &timeStamp, const unsigned char *signData,
const std::string &keyAttestation, const std::string &equipCert,
const std::string &equipCACert, const std::string &rootFirstCAPath,
const std::string &rootSecondCAPath, const std::string &equipCrlPath);
static bool initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath,
const std::string equipCrlPath, uint64_t replay_attack_time_diff_);
// verify valid of sign data
bool verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData, int srcDataLen);
void sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash, const int len) const;
// verify valid of time stamp of request
bool verifyTimeStamp(const std::string &flID, const std::string &timeStamp) const;
#ifndef _WIN32
private:
// read certificate from file path
static X509 *readCertFromFile(const std::string &certPath);
// read Certificate Revocation List from file absolute path
static X509_CRL *readCrlFromFile(const std::string &crlPath);
// read certificate from pem string
X509 *readCertFromPerm(std::string cert);
// verify valid of certificate time
bool verifyCertTime(const X509 *cert) const;
// verify valid of certificate chain
bool verifyCAChain(const std::string &keyAttestation, const std::string &equipCert, const std::string &equipCACert,
const std::string &rootFirstCAPath, const std::string &rootSecondCAPath);
// verify valid of sign data
bool verifyRSAKey(const std::string &keyAttestation, const unsigned char *signData, const std::string &flID,
const std::string &timeStamp);
// verify valid of equip certificate with CRL
bool verifyCRL(const std::string &equipCert, const std::string &equipCrlPath);
// verify valid of flID with sha256(equip cert)
bool verifyEquipCertAndFlID(const std::string &flID, const std::string &equipCert);
void sha256Hash(const std::string &src, uint8_t *hash, const int len) const;
std::string toHexString(const unsigned char *data, const int len);
bool verifyCertCommonName(const X509 *caCert, const X509 *subCert) const;
bool verifyExtendedAttributes(const X509 *cert) const;
bool verifyCertKeyID(const X509 *caCert, const X509 *subCert) const;
bool verifyPublicKey(const X509 *keyAttestationCertObj, const X509 *equipCertObj, const X509 *equipCACertObj,
const X509 *rootFirstCA, const X509 *rootSecondCA) const;
#endif
};
} // namespace server
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_SERVER_CERT_VERIFY_H

View File

@ -93,7 +93,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
return false; return false;
} }
// Step 3: Reduce the data so we can overlap the time cost of send. // Step 3: Reduce the data so we can overlap the time cost of send.
for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) { for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) {
recv_chunk[j] += tmp_recv_chunk[j]; recv_chunk[j] += tmp_recv_chunk[j];
@ -164,9 +163,9 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
for (uint32_t i = 1; i < rank_size; i++) { for (uint32_t i = 1; i < rank_size; i++) {
std::shared_ptr<std::vector<unsigned char>> recv_str; std::shared_ptr<std::vector<unsigned char>> recv_str;
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i; MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str); auto recv_req_id1 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) { if (!server_node_->CollectiveWait(recv_req_id1, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << recv_req_id1 << " failed.";
return false; return false;
} }
ret = memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size()); ret = memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
@ -180,9 +179,9 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
} }
} else { } else {
MS_LOG(DEBUG) << "Reduce send data to rank 0 process."; MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T)); auto send_req_id1 = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) { if (!server_node_->Wait(send_req_id1, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << send_req_id1 << " failed.";
return false; return false;
} }
} }
@ -193,19 +192,19 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
if (rank_id_ == 0) { if (rank_id_ == 0) {
for (uint32_t i = 1; i < rank_size; i++) { for (uint32_t i = 1; i < rank_size; i++) {
MS_LOG(DEBUG) << "Broadcast data to process " << i; MS_LOG(DEBUG) << "Broadcast data to process " << i;
auto send_req_id = auto send_req_id2 =
server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T)); server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) { if (!server_node_->Wait(send_req_id2, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << send_req_id2 << " failed.";
return false; return false;
} }
} }
} else { } else {
MS_LOG(DEBUG) << "Broadcast receive from rank 0."; MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
std::shared_ptr<std::vector<unsigned char>> recv_str; std::shared_ptr<std::vector<unsigned char>> recv_str;
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str); auto recv_req_id2 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) { if (!server_node_->CollectiveWait(recv_req_id2, kCollectiveCommTimeout)) {
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; MS_LOG(ERROR) << "CollectiveWait " << recv_req_id2 << " failed.";
return false; return false;
} }
ret = memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size()); ret = memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
@ -219,7 +218,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
} }
template <typename T> template <typename T>
bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count) { bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff, size_t send_count) {
MS_ERROR_IF_NULL_W_RET_VAL(node_, false); MS_ERROR_IF_NULL_W_RET_VAL(node_, false);
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false); MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false); MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
@ -230,8 +229,8 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size
// Store offsets to get every data chunk's address. // Store offsets to get every data chunk's address.
std::vector<size_t> chunk_offset; std::vector<size_t> chunk_offset;
for (size_t i = 0; i < rank_size_; i++) { for (size_t i = 0; i < rank_size_; i++) {
size_t ofs = size_t ofs = std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + SizeToLong(i), static_cast<size_t>(0),
std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast<size_t>(0), std::plus<size_t>()); std::plus<size_t>());
chunk_offset.push_back(ofs); chunk_offset.push_back(ofs);
} }
@ -295,7 +294,7 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c
MS_LOG(ERROR) << "The group is empty."; MS_LOG(ERROR) << "The group is empty.";
return false; return false;
} }
uint32_t group_rank_size = group_info.group_ranks.size(); uint32_t group_rank_size = SizeToUint(group_info.group_ranks.size());
uint32_t global_root_rank = group_to_global_ranks[root]; uint32_t global_root_rank = group_to_global_ranks[root];
// Broadcast data to processes which are not the root. // Broadcast data to processes which are not the root.
@ -349,7 +348,7 @@ bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t c
} }
template <typename T> template <typename T>
bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t send_count, bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *const recvbuff, size_t send_count,
const std::shared_ptr<ps::core::AbstractNode> &node) { const std::shared_ptr<ps::core::AbstractNode> &node) {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);
MS_ERROR_IF_NULL_W_RET_VAL(node, false); MS_ERROR_IF_NULL_W_RET_VAL(node, false);
@ -362,10 +361,10 @@ bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t s
rank_id_ = node_->rank_id(); rank_id_ = node_->rank_id();
switch (node_role_) { switch (node_role_) {
case ps::core::WORKER: case ps::core::WORKER:
rank_size_ = node_->worker_num(); rank_size_ = IntToUint(node_->worker_num());
break; break;
case ps::core::SERVER: case ps::core::SERVER:
rank_size_ = node_->server_num(); rank_size_ = IntToUint(node_->server_num());
break; break;
default: default:
MS_LOG(ERROR) << "The node role " << node_role_ << " for collective communication is invalid."; MS_LOG(ERROR) << "The node role " << node_role_ << " for collective communication is invalid.";
@ -380,7 +379,7 @@ bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t s
} }
template <typename T> template <typename T>
bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root, bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *const recvbuff, size_t count, uint32_t root,
const std::shared_ptr<ps::core::AbstractNode> &node, const std::shared_ptr<ps::core::AbstractNode> &node,
const CommunicationGroupInfo &group_info) { const CommunicationGroupInfo &group_info) {
std::unique_lock<std::mutex> lock(mtx_); std::unique_lock<std::mutex> lock(mtx_);

View File

@ -45,17 +45,17 @@ enum CommType { HTTP = 0, TCP };
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum }; enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
struct RoundConfig { struct RoundConfig {
// The name of round. Please refer to round kernel *.cc files. // The name of the round. Please refer to round kernel *.cc files.
std::string name; std::string name;
// Whether this round has the time window limit. // Whether this round has the time window limit.
bool check_timeout = false; bool check_timeout = false;
// The length of the time window. Only used when check_timeout is set to true. // The length of the time window. Only used when check_timeout is set to true.
size_t time_window = 3000; size_t time_window = 3000;
// Whether this round has to check the request count has reached the threshold. // Whether this round has to check the request count has reach the threshold.
bool check_count = false; bool check_count = false;
// This round's request threshold count. Only used when check_count is set to true. // This round's request threshold count. Only used when threshold_count is set to true.
size_t threshold_count = 0; size_t threshold_count = 0;
// Whether this round uses the server as threshold count. This is vital for some rounds in elastic scaling scenario. // Whether this round uses the server number as threshold. This is vital for some rounds in elastic scaling scenario.
bool server_num_as_threshold = false; bool server_num_as_threshold = false;
}; };
@ -67,6 +67,8 @@ struct CipherConfig {
size_t share_secrets_threshold = 0; size_t share_secrets_threshold = 0;
size_t get_secrets_threshold = 0; size_t get_secrets_threshold = 0;
size_t client_list_threshold = 0; size_t client_list_threshold = 0;
size_t push_list_sign_threshold = 0;
size_t get_list_sign_threshold = 0;
size_t reconstruct_secrets_threshold = 0; size_t reconstruct_secrets_threshold = 0;
}; };
@ -130,7 +132,6 @@ constexpr auto kAdamEps = "eps";
constexpr auto kFtrlLinear = "linear"; constexpr auto kFtrlLinear = "linear";
constexpr auto kDataSize = "data_size"; constexpr auto kDataSize = "data_size";
constexpr auto kNewDataSize = "new_data_size"; constexpr auto kNewDataSize = "new_data_size";
constexpr auto kStat = "stat";
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
// launched. // launched.
@ -175,14 +176,11 @@ const OptimParamNameToIndex kAdamWeightDecayNameToIdx = {{"inputs",
{"weight_decay", 7}, {"weight_decay", 7},
{"grad", 8}}}, {"grad", 8}}},
{"outputs", {}}}; {"outputs", {}}};
const OptimParamNameToIndex kSGDNameToIdx = { const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {{kApplyMomentumOpName, kMomentumNameToIdx},
{"inputs", {{kWeight, 0}, {kGradient, 1}, {kLearningRate, 2}, {kAccumulation, 3}, {kMomentum, 4}, {kStat, 5}}}, {kFusedSparseAdamName, kSparseAdamNameToIdx},
{"outputs", {}}}; {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
{kApplyAdamOpName, kAdamNameToIdx},
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = { {"AdamWeightDecay", kAdamWeightDecayNameToIdx}};
{kApplyMomentumOpName, kMomentumNameToIdx}, {kFusedSparseAdamName, kSparseAdamNameToIdx},
{kSparseApplyFtrlOpName, kSparseFtrlNameToIdx}, {kApplyAdamOpName, kAdamNameToIdx},
{"AdamWeightDecay", kAdamWeightDecayNameToIdx}, {kSGDName, kSGDNameToIdx}};
constexpr uint32_t kLeaderServerRank = 0; constexpr uint32_t kLeaderServerRank = 0;
constexpr size_t kWorkerMgrThreadPoolSize = 32; constexpr size_t kWorkerMgrThreadPoolSize = 32;
@ -215,9 +213,12 @@ constexpr auto kCtxGetSecretsClientList = "get_secrets_client_list";
constexpr auto kCtxReconstructClientList = "reconstruct_client_list"; constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list"; constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list"; constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list";
constexpr auto kCtxClientListSigns = "client_list_signs";
constexpr auto kCtxClientKeyAttestation = "client_key_attestation";
constexpr auto kCtxGetKeysClientList = "get_keys_client_list"; constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size"; constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
constexpr auto kCtxCipherPrimer = "cipher_primer"; constexpr auto kCtxCipherPrimer = "cipher_primer";
constexpr auto kCurrentIteration = "current_iteration";
// This macro the current timestamp in milliseconds. // This macro the current timestamp in milliseconds.
#define CURRENT_TIME_MILLI \ #define CURRENT_TIME_MILLI \
@ -252,6 +253,14 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size
return addr; return addr;
} }
template <typename T>
inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) {
if (!json.contains(key)) {
MS_LOG(EXCEPTION) << "The key " << key << "does not exist in json " << json.dump();
}
return json[key].get<T>();
}
// Definitions for Federated Learning. // Definitions for Federated Learning.
constexpr auto kNetworkError = "Cluster networking failed."; constexpr auto kNetworkError = "Cluster networking failed.";

View File

@ -26,7 +26,7 @@ bool ConsistentHashRing::Insert(uint32_t rank) {
MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank << ", hash value is " MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank << ", hash value is "
<< hash_value; << hash_value;
if (ring_.count(hash_value) != 0) { if (ring_.count(hash_value) != 0) {
MS_LOG(INFO) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring."; MS_LOG(WARNING) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring.";
continue; continue;
} }
ring_[hash_value] = rank; ring_[hash_value] = rank;
@ -37,7 +37,7 @@ bool ConsistentHashRing::Insert(uint32_t rank) {
bool ConsistentHashRing::Erase(uint32_t rank) { bool ConsistentHashRing::Erase(uint32_t rank) {
for (auto iterator = ring_.begin(); iterator != ring_.end();) { for (auto iterator = ring_.begin(); iterator != ring_.end();) {
if (iterator->second == rank) { if (iterator->second == rank) {
iterator = ring_.erase(iterator); (void)ring_.erase(iterator++);
} else { } else {
++iterator; ++iterator;
} }

View File

@ -329,6 +329,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
// Broadcast to all follower servers. // Broadcast to all follower servers.
for (uint32_t i = 1; i < server_num_; i++) { for (uint32_t i = 1; i < server_num_; i++) {
MS_LOG(INFO) << "Start sending first count event message to server " << i;
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed."; MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
if (reason != nullptr) { if (reason != nullptr) {
@ -343,7 +344,9 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
return false; return false;
} }
// Leader server directly calls the callback. // Leader server directly calls the callback.
MS_LOG(INFO) << "Leader server call first count handler for " << name << "...";
counter_handlers_[name].first_count_handler(nullptr); counter_handlers_[name].first_count_handler(nullptr);
MS_LOG(INFO) << "First count handler for " << name << " is successfully called.";
return true; return true;
} }
@ -355,6 +358,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
// Broadcast to all follower servers. // Broadcast to all follower servers.
for (uint32_t i = 1; i < server_num_; i++) { for (uint32_t i = 1; i < server_num_; i++) {
MS_LOG(INFO) << "Start sending last count event message to server " << i;
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed."; MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
if (reason != nullptr) { if (reason != nullptr) {
@ -369,7 +373,9 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
return false; return false;
} }
// Leader server directly calls the callback. // Leader server directly calls the callback.
MS_LOG(INFO) << "Leader server call last count handler for " << name << "...";
counter_handlers_[name].last_count_handler(nullptr); counter_handlers_[name].last_count_handler(nullptr);
MS_LOG(INFO) << "Last count handler for " << name << " is successfully called.";
return true; return true;
} }
} // namespace server } // namespace server

View File

@ -21,7 +21,6 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "utils/hash_map.h"
#include "proto/ps.pb.h" #include "proto/ps.pb.h"
#include "fl/server/common.h" #include "fl/server/common.h"
#include "ps/core/server_node.h" #include "ps/core/server_node.h"
@ -118,14 +117,14 @@ class DistributedCountService {
// Key: name, e.g, startFLJob, updateModel, push. // Key: name, e.g, startFLJob, updateModel, push.
// Value: a set of id without repeatation because each work may report multiple times. // Value: a set of id without repeatation because each work may report multiple times.
mindspore::HashMap<std::string, std::set<std::string>> global_current_count_; std::unordered_map<std::string, std::set<std::string>> global_current_count_;
// Key: name, e.g, StartFLJobCount. // Key: name, e.g, StartFLJobCount.
// Value: global threshold count in the server cluster dimension for this name. // Value: global threshold count in the server cluster dimension for this name.
mindspore::HashMap<std::string, size_t> global_threshold_count_; std::unordered_map<std::string, size_t> global_threshold_count_;
// First/last count event callbacks of the name. // First/last count event callbacks of the name.
mindspore::HashMap<std::string, CounterHandlers> counter_handlers_; std::unordered_map<std::string, CounterHandlers> counter_handlers_;
// Because the count is increased/queried conccurently, we must ensure the operations are threadsafe. // Because the count is increased/queried conccurently, we must ensure the operations are threadsafe.
std::unordered_map<std::string, std::mutex> mutex_; std::unordered_map<std::string, std::mutex> mutex_;

View File

@ -125,10 +125,6 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
return {}; return {};
} }
if (metadata_.count(name) == 0) {
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
return {};
}
uint32_t stored_rank = router_->Find(name); uint32_t stored_rank = router_->Find(name);
MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank; MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
if (local_rank_ == stored_rank) { if (local_rank_ == stored_rank) {

View File

@ -20,7 +20,6 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "utils/hash_map.h"
#include "proto/ps.pb.h" #include "proto/ps.pb.h"
#include "fl/server/common.h" #include "fl/server/common.h"
#include "ps/core/server_node.h" #include "ps/core/server_node.h"
@ -107,7 +106,7 @@ class DistributedMetadataStore {
// We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use. // We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use.
// Key: data name. // Key: data name.
// Value: ProtoBuffer Struct. // Value: ProtoBuffer Struct.
mindspore::HashMap<std::string, PBMetadata> metadata_; std::unordered_map<std::string, PBMetadata> metadata_;
// Because the metadata is read/written conccurently, we must ensure the operations are threadsafe. // Because the metadata is read/written conccurently, we must ensure the operations are threadsafe.
std::unordered_map<std::string, std::mutex> mutex_; std::unordered_map<std::string, std::mutex> mutex_;

View File

@ -65,47 +65,6 @@ bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
bool Executor::initialized() const { return initialized_; } bool Executor::initialized() const { return initialized_; }
bool Executor::HandlePush(const std::string &param_name, const UploadData &upload_data) {
MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
if (param_aggrs_.count(param_name) == 0) {
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
return false;
}
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
// Push operation needs to wait until the pulling process is done.
while (!param_aggr->IsPullingDone()) {
lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
lock.lock();
}
// 1.Update data with the uploaded data of the worker.
if (!param_aggr->UpdateData(upload_data)) {
MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
return false;
}
// 2.Launch aggregation for this trainable parameter.
if (!param_aggr->LaunchAggregators()) {
MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
return false;
}
if (param_aggr->IsAggregationDone()) {
// 3.After the aggregation is done, optimize the trainable parameter.
if (!param_aggr->LaunchOptimizers()) {
MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
return false;
}
// 4.Reset pulling and aggregation status after optimizing is done.
param_aggr->ResetPullingStatus();
param_aggr->ResetAggregationStatus();
}
return true;
}
bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) { bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData &upload_data) {
MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name; MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
if (param_aggrs_.count(param_name) == 0) { if (param_aggrs_.count(param_name) == 0) {
@ -131,32 +90,6 @@ bool Executor::HandleModelUpdate(const std::string &param_name, const UploadData
return true; return true;
} }
bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
std::unique_lock<std::mutex> model_lock(model_mutex_);
for (const auto &trainable_param : feature_map) {
const std::string &param_name = trainable_param.first;
if (param_aggrs_.count(param_name) == 0) {
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
continue;
}
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
const UploadData &upload_data = trainable_param.second;
if (!param_aggr->UpdateData(upload_data)) {
MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
return false;
}
if (!param_aggr->LaunchAggregators()) {
MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
return false;
}
}
return true;
}
bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) { bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
for (const auto &trainable_param : feature_map) { for (const auto &trainable_param : feature_map) {
const std::string &param_name = trainable_param.first; const std::string &param_name = trainable_param.first;
@ -183,31 +116,6 @@ bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_ma
return true; return true;
} }
AddressPtr Executor::HandlePull(const std::string &param_name) {
MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
if (param_aggrs_.count(param_name) == 0) {
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
return nullptr;
}
std::mutex &mtx = parameter_mutex_[param_name];
std::unique_lock<std::mutex> lock(mtx);
auto &param_aggr = param_aggrs_[param_name];
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, nullptr);
// Pulling must wait until the optimizing process is done.
while (!param_aggr->IsOptimizingDone()) {
lock.unlock();
std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
lock.lock();
}
AddressPtr addr = param_aggr->Pull();
// If this Pull is the last one, reset pulling and optimizing status.
if (param_aggr->IsPullingDone()) {
param_aggr->ResetOptimizingStatus();
}
return addr;
}
std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) { std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> &param_names) {
std::map<std::string, AddressPtr> weights; std::map<std::string, AddressPtr> weights;
for (const auto &param_name : param_names) { for (const auto &param_name : param_names) {
@ -298,7 +206,7 @@ bool Executor::unmasked() const {
if (encrypt_type == ps::kPWEncryptType) { if (encrypt_type == ps::kPWEncryptType) {
return unmasked_.load(); return unmasked_.load();
} else { } else {
// If the algorithm of pairwise encrypt is not enabled, consider_ unmasked flag as true. // If the algorithm of mind armour is not enabled, consider unmasked_ flag as true.
return true; return true;
} }
} }
@ -340,7 +248,7 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
param_aggrs_[param_name] = param_aggr; param_aggrs_[param_name] = param_aggr;
parameter_mutex_[param_name]; parameter_mutex_[param_name];
if (!param_aggr->Init(cnode, aggregation_count_)) { if (!param_aggr->Init(cnode, aggregation_count_)) {
MS_LOG(EXCEPTION) << "Initializing parameter aggregator for " << param_name << " failed."; MS_LOG(EXCEPTION) << "Initializing parameter aggregator for param_name " << param_name << " failed.";
return false; return false;
} }
MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success."; MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";

View File

@ -24,11 +24,11 @@
#include <vector> #include <vector>
#include <mutex> #include <mutex>
#include <condition_variable> #include <condition_variable>
#include "fl/server/common.h"
#include "fl/server/parameter_aggregator.h"
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_unmask.h" #include "fl/armour/cipher/cipher_unmask.h"
#endif #endif
#include "fl/server/common.h"
#include "fl/server/parameter_aggregator.h"
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
@ -54,28 +54,13 @@ class Executor {
// After hyper-parameters are updated, some parameter aggregators should be reinitialized. // After hyper-parameters are updated, some parameter aggregators should be reinitialized.
bool ReInitForUpdatingHyperParams(size_t aggr_threshold); bool ReInitForUpdatingHyperParams(size_t aggr_threshold);
// Called in parameter server training mode to do Push operation.
// For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered
// as completed.
bool HandlePush(const std::string &param_name, const UploadData &upload_data);
// Called in parameter server training mode to do Pull operation.
// Returns the value of parameter param_name.
// HandlePull method must be called the same times as HandlePush is called before it's considered as
// completed.
AddressPtr HandlePull(const std::string &param_name);
// Called in federated learning training mode. Update value for parameter param_name. // Called in federated learning training mode. Update value for parameter param_name.
bool HandleModelUpdate(const std::string &param_name, const UploadData &upload_data); bool HandleModelUpdate(const std::string &param_name, const UploadData &upload_data);
// Called in asynchronous federated learning training mode. Update current model with the new feature map // Forcibly overwrite specific weights in overwriteWeights message.
// asynchronously.
bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
// Overwrite the weights in server using pushed feature map.
bool HandlePushWeight(const std::map<std::string, Address> &feature_map); bool HandlePushWeight(const std::map<std::string, Address> &feature_map);
// Returns multiple trainable parameters passed by weight_names. // Returns value for multiple trainable parameters passed by weight_names.
std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> &param_names); std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> &param_names);
// Reset the aggregation status for all aggregation kernels in the server. // Reset the aggregation status for all aggregation kernels in the server.
@ -135,7 +120,7 @@ class Executor {
armour::CipherUnmask cipher_unmask_; armour::CipherUnmask cipher_unmask_;
#endif #endif
// The flag represents the unmasking status. // The flag refers to the unmasking status
std::atomic<bool> unmasked_; std::atomic<bool> unmasked_;
}; };
} // namespace server } // namespace server

View File

@ -16,8 +16,8 @@
#include "fl/server/iteration.h" #include "fl/server/iteration.h"
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include <string>
#include <numeric> #include <numeric>
#include "fl/server/model_store.h" #include "fl/server/model_store.h"
#include "fl/server/server.h" #include "fl/server/server.h"
@ -38,6 +38,7 @@ Iteration::~Iteration() {
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) { void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
MS_EXCEPTION_IF_NULL(communicator); MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator; communicator_ = communicator;
MS_EXCEPTION_IF_NULL(communicator_);
communicator_->RegisterMsgCallBack("syncIteration", communicator_->RegisterMsgCallBack("syncIteration",
std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1)); std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
communicator_->RegisterMsgCallBack( communicator_->RegisterMsgCallBack(
@ -54,10 +55,10 @@ void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommu
void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) { void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) {
MS_EXCEPTION_IF_NULL(server_node); MS_EXCEPTION_IF_NULL(server_node);
server_node_ = server_node; server_node_ = server_node;
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning), server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::UserDefineEvent::kIterationRunning),
std::bind(&Iteration::HandleIterationRunningEvent, this)); std::bind(&Iteration::ProcessIterationRunningEvent, this));
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted), server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::UserDefineEvent::kIterationCompleted),
std::bind(&Iteration::HandleIterationCompletedEvent, this)); std::bind(&Iteration::ProcessIterationEndEvent, this));
} }
void Iteration::AddRound(const std::shared_ptr<Round> &round) { void Iteration::AddRound(const std::shared_ptr<Round> &round) {
@ -97,6 +98,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
if (!move_to_next_thread_running_.load()) { if (!move_to_next_thread_running_.load()) {
break; break;
} }
lock.unlock();
MoveToNextIteration(is_last_iteration_valid_, move_to_next_reason_); MoveToNextIteration(is_last_iteration_valid_, move_to_next_reason_);
} }
}); });
@ -113,6 +115,7 @@ void Iteration::NotifyNext(bool is_last_iter_valid, const std::string &reason) {
} }
void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &reason) { void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Notify cluster starts to proceed to next iteration. Iteration is " << iteration_num_ MS_LOG(INFO) << "Notify cluster starts to proceed to next iteration. Iteration is " << iteration_num_
<< " validation is " << is_last_iter_valid << ". Reason: " << reason; << " validation is " << is_last_iter_valid << ". Reason: " << reason;
if (IsMoveToNextIterRequestReentrant(iteration_num_)) { if (IsMoveToNextIterRequestReentrant(iteration_num_)) {
@ -147,7 +150,13 @@ void Iteration::SetIterationRunning() {
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == kLeaderServerRank) { if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state. // This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning)); server_node_->BroadcastEvent(static_cast<uint32_t>(ps::UserDefineEvent::kIterationRunning));
if (server_recovery_ != nullptr) {
// Save data to the persistent storage in case the recovery happens at the beginning.
if (!server_recovery_->Save(iteration_num_)) {
MS_LOG(WARNING) << "Save recovery data failed.";
}
}
} }
std::unique_lock<std::mutex> lock(iteration_state_mtx_); std::unique_lock<std::mutex> lock(iteration_state_mtx_);
@ -155,12 +164,12 @@ void Iteration::SetIterationRunning() {
start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()); start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
} }
void Iteration::SetIterationCompleted() { void Iteration::SetIterationEnd() {
MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes."; MS_LOG(INFO) << "Iteration " << iteration_num_ << " ends.";
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == kLeaderServerRank) { if (server_node_->rank_id() == kLeaderServerRank) {
// This event helps worker/server to be consistent in iteration state. // This event helps worker/server to be consistent in iteration state.
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted)); server_node_->BroadcastEvent(static_cast<uint32_t>(ps::UserDefineEvent::kIterationCompleted));
} }
std::unique_lock<std::mutex> lock(iteration_state_mtx_); std::unique_lock<std::mutex> lock(iteration_state_mtx_);
@ -274,7 +283,7 @@ bool Iteration::DisableServerInstance(std::string *result) {
instance_state_ = InstanceState::kDisable; instance_state_ = InstanceState::kDisable;
if (!ForciblyMoveToNextIteration()) { if (!ForciblyMoveToNextIteration()) {
*result = "Disabling instance failed. Can't drop current iteration and move to the next."; *result = "Disabling instance failed. Can't drop current iteration and move to the next.";
MS_LOG(ERROR) << *result; MS_LOG(ERROR) << result;
return false; return false;
} }
*result = "Disabling FL-Server succeeded."; *result = "Disabling FL-Server succeeded.";
@ -295,6 +304,11 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
return false; return false;
} }
if (iteration_num_ == 1) {
MS_LOG(INFO) << "This is just the first iteration.";
return true;
}
// Start new server instance. // Start new server instance.
is_instance_being_updated_ = true; is_instance_being_updated_ = true;
@ -312,7 +326,7 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
ModelStore::GetInstance().Reset(); ModelStore::GetInstance().Reset();
if (metrics_ != nullptr) { if (metrics_ != nullptr) {
if (!metrics_->Clear()) { if (!metrics_->Clear()) {
MS_LOG(WARNING) << "Clear metrics fil failed."; MS_LOG(WARNING) << "Clear metrics file failed.";
} }
} }
@ -340,6 +354,16 @@ void Iteration::WaitAllRoundsFinish() const {
} }
} }
void Iteration::set_recovery_handler(const std::shared_ptr<ServerRecovery> &server_recovery) {
MS_EXCEPTION_IF_NULL(server_recovery);
server_recovery_ = server_recovery;
}
bool Iteration::SyncAfterRecovery(uint64_t) {
NotifyNext(false, "Move to next iteration after recovery.");
return true;
}
bool Iteration::SyncIteration(uint32_t rank) { bool Iteration::SyncIteration(uint32_t rank) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false); MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
SyncIterationRequest sync_iter_req; SyncIterationRequest sync_iter_req;
@ -348,7 +372,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr; std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration, if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration,
&sync_iter_rsp_msg)) { &sync_iter_rsp_msg)) {
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed."; MS_LOG(ERROR) << "Sending sync iter message to leader server failed.";
return false; return false;
} }
@ -356,8 +380,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
SyncIterationResponse sync_iter_rsp; SyncIterationResponse sync_iter_rsp;
(void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size())); (void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
iteration_num_ = sync_iter_rsp.iteration(); iteration_num_ = sync_iter_rsp.iteration();
MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is " MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is " << iteration_num_;
<< sync_iter_rsp.iteration();
return true; return true;
} }
@ -496,7 +519,9 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::
void Iteration::PrepareForNextIter() { void Iteration::PrepareForNextIter() {
MS_LOG(INFO) << "Prepare for next iteration. Switch the server to safemode."; MS_LOG(INFO) << "Prepare for next iteration. Switch the server to safemode.";
Server::GetInstance().SwitchToSafeMode(); Server::GetInstance().SwitchToSafeMode();
MS_LOG(INFO) << "Start waiting for rounds to finish.";
WaitAllRoundsFinish(); WaitAllRoundsFinish();
MS_LOG(INFO) << "End waiting for rounds to finish.";
} }
bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) { bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
@ -627,12 +652,22 @@ void Iteration::EndLastIter() {
pinned_iter_num_ = 0; pinned_iter_num_ = 0;
lock.unlock(); lock.unlock();
SetIterationCompleted(); SetIterationEnd();
if (!SummarizeIteration()) { if (!SummarizeIteration()) {
MS_LOG(WARNING) << "Summarizing iteration data failed."; MS_LOG(WARNING) << "Summarizing iteration data failed.";
} }
iteration_num_++; iteration_num_++;
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
if (server_node_->rank_id() == kLeaderServerRank) {
// Save current iteration number for recovery.
MS_ERROR_IF_NULL_WO_RET_VAL(server_recovery_);
if (!server_recovery_->Save(iteration_num_)) {
MS_LOG(WARNING) << "Can't save current iteration number into persistent storage.";
}
}
Server::GetInstance().CancelSafeMode(); Server::GetInstance().CancelSafeMode();
iteration_state_cv_.notify_all(); iteration_state_cv_.notify_all();
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n"; MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";

View File

@ -25,6 +25,7 @@
#include "fl/server/round.h" #include "fl/server/round.h"
#include "fl/server/local_meta_store.h" #include "fl/server/local_meta_store.h"
#include "fl/server/iteration_metrics.h" #include "fl/server/iteration_metrics.h"
#include "fl/server/server_recovery.h"
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
@ -52,7 +53,7 @@ class Iteration {
// Register callbacks for other servers to synchronize iteration information from leader server. // Register callbacks for other servers to synchronize iteration information from leader server.
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator); void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
// Register event callbacks for iteration state synchronization. // Register event callback for iteration state synchronization.
void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node); void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node);
// Add a round for the iteration. This method will be called multiple times for each round. // Add a round for the iteration. This method will be called multiple times for each round.
@ -70,14 +71,14 @@ class Iteration {
// This method will control servers to proceed to next iteration. // This method will control servers to proceed to next iteration.
// There's communication between leader and follower servers in this method. // There's communication between leader and follower servers in this method.
// The server moves to next iteration only after the last round finishes or the time expires. // The server moves to the next iteration only after the last round finishes or the timer expires.
void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason); void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
// Set current iteration state to running and trigger events about kIterationRunning. // Set current iteration state to running and trigger the event.
void SetIterationRunning(); void SetIterationRunning();
// Set current iteration state to completed and trigger the event about kIterationCompleted. // Set current iteration state to end and trigger the event.
void SetIterationCompleted(); void SetIterationEnd();
// The barrier function for elastic scaling. The scaling out/in operation should be done only after this iteration is // The barrier function for elastic scaling. The scaling out/in operation should be done only after this iteration is
// completed. // completed.
@ -118,6 +119,12 @@ class Iteration {
// Need to wait all the rounds to finish before proceed to next iteration. // Need to wait all the rounds to finish before proceed to next iteration.
void WaitAllRoundsFinish() const; void WaitAllRoundsFinish() const;
// Set server's recovery handler.
void set_recovery_handler(const std::shared_ptr<ServerRecovery> &server_recovery);
// Synchronize server iteration after another server's recovery is completed.
bool SyncAfterRecovery(uint64_t iteration_num);
// The round kernels whose Launch method has not returned yet. // The round kernels whose Launch method has not returned yet.
std::atomic_uint32_t running_round_num_; std::atomic_uint32_t running_round_num_;
@ -150,10 +157,10 @@ class Iteration {
Iteration &operator=(const Iteration &) = delete; Iteration &operator=(const Iteration &) = delete;
// The server does not need to handle the iteration events for now. // The server does not need to handle the iteration events for now.
void HandleIterationRunningEvent() {} void ProcessIterationRunningEvent() {}
void HandleIterationCompletedEvent() {} void ProcessIterationEndEvent() {}
// Synchronize iteration form the leader server(Rank 0). // Synchronize iteration from the leader server(Rank 0).
bool SyncIteration(uint32_t rank); bool SyncIteration(uint32_t rank);
void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message); void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
@ -165,13 +172,13 @@ class Iteration {
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason); bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message); void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Step 2: leader server broadcast to all follower servers to prepare for next iteration and switch to safemode. // Step 2: leader server broadcasts to all follower servers to prepare for next iteration and switch to safemode..
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason); bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message); void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The server prepare for the next iteration. This method will switch the server to safemode. // The server prepare for the next iteration. This method will switch the server to safemode.
void PrepareForNextIter(); void PrepareForNextIter();
// Step 3: leader server broadcast to all follower servers to move to next iteration. // Step 3: leader server broadcasts to all follower servers to move to next iteration.
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason); bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message); void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Move to next iteration. Store last iterations model and reset all the rounds. // Move to next iteration. Store last iterations model and reset all the rounds.
@ -201,6 +208,9 @@ class Iteration {
// All the rounds in the server. // All the rounds in the server.
std::vector<std::shared_ptr<Round>> rounds_; std::vector<std::shared_ptr<Round>> rounds_;
// The recovery object for server.
std::shared_ptr<ServerRecovery> server_recovery_;
// The iteration is either running or completed at any time. // The iteration is either running or completed at any time.
std::mutex iteration_state_mtx_; std::mutex iteration_state_mtx_;
std::condition_variable iteration_state_cv_; std::condition_variable iteration_state_cv_;

View File

@ -17,6 +17,7 @@
#include "fl/server/iteration_metrics.h" #include "fl/server/iteration_metrics.h"
#include <string> #include <string>
#include <fstream> #include <fstream>
#include "utils/file_utils.h"
#include "debug/common.h" #include "debug/common.h"
#include "ps/constants.h" #include "ps/constants.h"
@ -68,7 +69,7 @@ bool IterationMetrics::Initialize() {
} }
bool IterationMetrics::Summarize() { bool IterationMetrics::Summarize() {
metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out); metrics_file_.open(metrics_file_path_, std::ios::out | std::ios::app);
if (!metrics_file_.is_open()) { if (!metrics_file_.is_open()) {
MS_LOG(ERROR) << "The metrics file is not opened."; MS_LOG(ERROR) << "The metrics file is not opened.";
return false; return false;

View File

@ -39,20 +39,11 @@ constexpr auto kRejectedClientNum = "rejectedClientNum";
constexpr auto kMetricsAuc = "metricsAuc"; constexpr auto kMetricsAuc = "metricsAuc";
constexpr auto kMetricsLoss = "metricsLoss"; constexpr auto kMetricsLoss = "metricsLoss";
constexpr auto kIterExecutionTime = "iterationExecutionTime"; constexpr auto kIterExecutionTime = "iterationExecutionTime";
constexpr auto kMetrics = "metrics";
const std::map<InstanceState, std::string> kInstanceStateName = { const std::map<InstanceState, std::string> kInstanceStateName = {
{InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}}; {InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}};
template <typename T>
inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) {
if (!json.contains(key)) {
MS_LOG(EXCEPTION) << "The key " << key << "does not exist in json " << json.dump();
}
return json[key].get<T>();
}
constexpr auto kMetrics = "metrics";
class IterationMetrics { class IterationMetrics {
public: public:
explicit IterationMetrics(const std::string &config_file) explicit IterationMetrics(const std::string &config_file)

View File

@ -19,6 +19,13 @@
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
IterationTimer::~IterationTimer() {
running_ = false;
if (monitor_thread_.joinable()) {
monitor_thread_.join();
}
}
void IterationTimer::Start(const std::chrono::milliseconds &duration) { void IterationTimer::Start(const std::chrono::milliseconds &duration) {
if (running_.load()) { if (running_.load()) {
MS_LOG(WARNING) << "The timer already started."; MS_LOG(WARNING) << "The timer already started.";
@ -50,7 +57,7 @@ void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) {
return; return;
} }
bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) const { bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) {
return timestamp > end_time_ ? true : false; return timestamp > end_time_ ? true : false;
} }

View File

@ -30,7 +30,7 @@ namespace server {
class IterationTimer { class IterationTimer {
public: public:
IterationTimer() : running_(false), end_time_(0) {} IterationTimer() : running_(false), end_time_(0) {}
~IterationTimer() = default; ~IterationTimer();
// Start timing. The timer will stop after parameter 'duration' milliseconds. // Start timing. The timer will stop after parameter 'duration' milliseconds.
void Start(const std::chrono::milliseconds &duration); void Start(const std::chrono::milliseconds &duration);
@ -42,7 +42,7 @@ class IterationTimer {
void SetTimeOutCallBack(const TimeOutCb &timeout_cb); void SetTimeOutCallBack(const TimeOutCb &timeout_cb);
// Judge whether current timestamp is out of time window's range since the Start function is called. // Judge whether current timestamp is out of time window's range since the Start function is called.
bool IsTimeOut(const std::chrono::milliseconds &timestamp) const; bool IsTimeOut(const std::chrono::milliseconds &timestamp);
// Judge whether the timer is keeping timing. // Judge whether the timer is keeping timing.
bool IsRunning() const; bool IsRunning() const;

View File

@ -36,59 +36,10 @@ class DenseGradAccumKernel : public AggregationKernel {
DenseGradAccumKernel() = default; DenseGradAccumKernel() = default;
~DenseGradAccumKernel() override = default; ~DenseGradAccumKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override { void InitKernel(const CNodePtr &) override { return; }
MS_EXCEPTION_IF_NULL(kernel_node);
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
kNameToIdxMap.at(cnode_name).at("inputs").count("grad") == 0) {
MS_LOG(EXCEPTION) << "Can't find index info of grad for kernel " << cnode_name;
return;
}
size_t cnode_grad_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("grad");
std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_grad_idx);
size_t grad_size = std::accumulate(grad_shape.begin(), grad_shape.end(), sizeof(T), std::multiplies<size_t>());
input_size_list_.push_back(grad_size);
size_t new_grad_size = grad_size;
input_size_list_.push_back(new_grad_size);
GenerateReuseKernelNodeInfo();
return;
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &) override { const std::vector<AddressPtr> &) override {
if (inputs.size() != kDenseGradAccumKernelInputsNum) {
MS_LOG(ERROR) << "The inputs number of DenseGradAccumKernel should be 2, but got " << inputs.size();
return false;
}
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0], false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1], false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false);
if (accum_count_ == 0) {
int ret = memset_s(inputs[0]->addr, inputs[0]->size, 0x00, inputs[0]->size);
if (ret != 0) {
MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
return false;
}
}
T *grad_addr = reinterpret_cast<T *>(inputs[0]->addr);
T *new_grad_addr = reinterpret_cast<T *>(inputs[1]->addr);
for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) {
grad_addr[i] += new_grad_addr[i];
}
accum_count_++;
if (accum_count_ > done_count_) {
MS_LOG(ERROR) << "accum_count_ should not be greater than done_count_ " << done_count_;
return false;
}
if (accum_count_ == done_count_) {
for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) {
grad_addr[i] /= done_count_;
}
}
return true; return true;
} }

View File

@ -108,8 +108,8 @@ class FedAvgKernel : public AggregationKernel {
done_ = true; done_ = true;
return; return;
}; };
GenerateReuseKernelNodeInfo();
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_}); DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_});
GenerateReuseKernelNodeInfo();
return; return;
} }
@ -119,14 +119,9 @@ class FedAvgKernel : public AggregationKernel {
MS_LOG(ERROR) << "The inputs number of FedAvgKernel should be 4, but got " << inputs.size(); MS_LOG(ERROR) << "The inputs number of FedAvgKernel should be 4, but got " << inputs.size();
return false; return false;
} }
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0], false); for (size_t i = 0; i < inputs.size(); i++) {
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1], false); MS_ERROR_IF_NULL_W_RET_VAL(inputs[i]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[2], false); }
MS_ERROR_IF_NULL_W_RET_VAL(inputs[3], false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[2]->addr, false);
MS_ERROR_IF_NULL_W_RET_VAL(inputs[3]->addr, false);
std::unique_lock<std::mutex> lock(weight_mutex_); std::unique_lock<std::mutex> lock(weight_mutex_);
// The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication // The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication
@ -181,7 +176,7 @@ class FedAvgKernel : public AggregationKernel {
bool ReInitForUpdatingHyperParams(size_t aggr_threshold) override { bool ReInitForUpdatingHyperParams(size_t aggr_threshold) override {
done_count_ = aggr_threshold; done_count_ = aggr_threshold;
if (!DistributedCountService::GetInstance().ReInitCounter(name_, done_count_)) { if (!DistributedCountService::GetInstance().ReInitCounter(name_, done_count_)) {
MS_LOG(ERROR) << "Reinitializing counter for " << name_ << " failed."; MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed.";
return false; return false;
} }
return true; return true;

View File

@ -21,7 +21,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include "utils/hash_map.h" #include <unordered_map>
#include "fl/server/common.h" #include "fl/server/common.h"
#include "fl/server/kernel/params_info.h" #include "fl/server/kernel/params_info.h"
@ -83,7 +83,7 @@ class KernelFactory {
// Generally, a server kernel can correspond to several ParamsInfo which is registered by the method 'Register' in // Generally, a server kernel can correspond to several ParamsInfo which is registered by the method 'Register' in
// server kernel's *.cc files. // server kernel's *.cc files.
mindspore::HashMap<std::string, std::vector<std::pair<ParamsInfo, C>>> name_to_creator_map_; std::unordered_map<std::string, std::vector<std::pair<ParamsInfo, C>>> name_to_creator_map_;
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -21,49 +21,7 @@ namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) { bool OptimizerKernelFactory::Matched(const ParamsInfo &, const CNodePtr &) { return true; }
MS_EXCEPTION_IF_NULL(kernel_node);
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
if (kNameToIdxMap.count(cnode_name) == 0) {
MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name;
return false;
}
auto input_name_to_idx = kNameToIdxMap.at(cnode_name).at("inputs");
size_t input_num = params_info.inputs_num();
for (size_t i = 0; i < input_num; i++) {
auto one_input_name_type = params_info.inputs_name_type(i);
std::string name = one_input_name_type.first;
if (input_name_to_idx.count(name) == 0) {
MS_LOG(EXCEPTION) << cnode_name << " does not have input named " << name;
return false;
}
size_t input_idx = input_name_to_idx.at(name);
TypeId kernel_node_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_idx);
TypeId registered_input_type = one_input_name_type.second;
if (registered_input_type != kernel_node_input_type) {
return false;
}
}
auto output_name_to_idx = kNameToIdxMap.at(cnode_name).at("outputs");
size_t output_num = params_info.outputs_num();
for (size_t i = 0; i < output_num; i++) {
auto one_output_name_type = params_info.outputs_name_type(i);
std::string name = one_output_name_type.first;
if (output_name_to_idx.count(name) == 0) {
MS_LOG(EXCEPTION) << cnode_name << " does not have output named " << name;
return false;
}
size_t output_idx = output_name_to_idx.at(name);
TypeId kernel_node_output_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_idx);
TypeId registered_output_type = one_output_name_type.second;
if (registered_output_type != kernel_node_output_type) {
return false;
}
}
return true;
}
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server
} // namespace fl } // namespace fl

View File

@ -19,6 +19,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include "schema/cipher_generated.h" #include "schema/cipher_generated.h"
namespace mindspore { namespace mindspore {
@ -29,16 +30,50 @@ void ClientListKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
} }
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_init_ = &armour::CipherInit::GetInstance(); cipher_init_ = &armour::CipherInit::GetInstance();
} }
sigVerifyResult ClientListKernel::VerifySignature(const schema::GetClientList *get_clients_req) {
std::string fl_id = get_clients_req->fl_id()->str();
std::string timestamp = get_clients_req->timestamp()->str();
int iteration = get_clients_req->iteration();
std::string iter_str = std::to_string(iteration);
auto fbs_signature = get_clients_req->signature();
std::vector<unsigned char> signature;
if (fbs_signature == nullptr) {
MS_LOG(ERROR) << "signature in get_clients_req is nullptr";
return sigVerifyResult::FAILED;
}
signature.assign(fbs_signature->begin(), fbs_signature->end());
std::map<std::string, std::string> key_attestations;
const fl::PBMetadata &key_attestations_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
auto iter = key_attestation_pb.key_attestations().begin();
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
}
if (key_attestations.find(fl_id) == key_attestations.end()) {
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
return sigVerifyResult::FAILED;
}
std::vector<unsigned char> src_data;
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
mindspore::ps::server::CertVerify certVerify;
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
return sigVerifyResult::FAILED;
}
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
return sigVerifyResult::TIMEOUT;
}
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
return sigVerifyResult::PASSED;
}
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req, bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
const std::shared_ptr<server::FBBuilder> &fbb) { const std::shared_ptr<server::FBBuilder> &fbb) {
std::vector<string> client_list; std::vector<string> client_list;
@ -48,7 +83,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) { if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
MS_LOG(ERROR) << "update_model_client_threshold is not set."; MS_LOG(ERROR) << "update_model_client_threshold is not set.";
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_threshold is not set.", BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_threshold is not set.",
empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false; return false;
} }
uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld); uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
@ -57,11 +92,11 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
for (size_t i = 0; i < IntToSize(client_list_pb.fl_id_size()); ++i) { for (size_t i = 0; i < IntToSize(client_list_pb.fl_id_size()); ++i) {
client_list.push_back(client_list_pb.fl_id(SizeToInt(i))); client_list.push_back(client_list_pb.fl_id(SizeToInt(i)));
} }
if (static_cast<uint64_t>(client_list.size()) < update_model_client_needed) { if (client_list.size() < update_model_client_needed) {
MS_LOG(INFO) << "The server is not ready. update_model_client_needed: " << update_model_client_needed; MS_LOG(INFO) << "The server is not ready. update_model_client_needed: " << update_model_client_needed;
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size(); MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", empty_client_list, BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false; return false;
} }
@ -69,7 +104,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
std::string reason = "fl_id: " + fl_id + " is not in the update_model_clients"; std::string reason = "fl_id: " + fl_id + " is not in the update_model_clients";
MS_LOG(INFO) << reason; MS_LOG(INFO) << reason;
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, empty_client_list, BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false; return false;
} }
@ -79,34 +114,27 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
std::string reason = "update get update model clients failed"; std::string reason = "update get update model clients failed";
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, reason, empty_client_list, BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return false; return false;
} }
if (!DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str())) { if (!DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str())) {
std::string reason = "Counting for get user list request failed. Please retry later."; std::string reason = "Counting for get user list request failed. Please retry later.";
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list, BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
MS_LOG(INFO) << "send clients_list succeed!";
MS_LOG(INFO) << "UpdateModel client list: ";
for (size_t i = 0; i < client_list.size(); ++i) {
MS_LOG(INFO) << " fl_id : " << client_list[i];
}
MS_LOG(INFO) << "update_model_client_needed: " << update_model_client_needed; MS_LOG(INFO) << "update_model_client_needed: " << update_model_client_needed;
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list, BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
return true; return true;
} }
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); MS_LOG(INFO) << "Launching ClientListKernel, Iteration number is " << iter_num;
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
clock_t start_time = clock();
if (inputs.size() != 1 || outputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid."; std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
@ -121,26 +149,66 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return false; return false;
} }
std::vector<string> client_list; std::vector<string> client_list;
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::GetClientList>()) {
std::string reason = "The schema of GetClientList is invalid.";
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data); const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
if (get_clients_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for GetClientList.";
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(get_clients_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, client_list,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::PASSED) {
MS_LOG(INFO) << "verify signature passed!";
}
}
size_t iter_client = IntToSize(get_clients_req->iteration()); size_t iter_client = IntToSize(get_clients_req->iteration());
if (iter_num != iter_client) { if (iter_num != iter_client) {
MS_LOG(ERROR) << "client list iteration number is invalid: server now iteration is " << iter_num MS_LOG(ERROR) << "client list iteration number is invalid: server now iteration is " << iter_num
<< ". client request iteration is " << iter_client; << ". client request iteration is " << iter_client;
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list, BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetClientList is enough."; MS_LOG(WARNING) << "Current amount for GetClientList is enough.";
} }
(void)DealClient(iter_num, get_clients_req, fbb); if (!DealClient(iter_num, get_clients_req, fbb)) {
MS_LOG(WARNING) << "Get Client List not ready.";
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
return true; return true;
} // namespace fl } // namespace fl
@ -153,28 +221,28 @@ bool ClientListKernel::Reset() {
return true; return true;
} }
void ClientListKernel::BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &client_list_resp_builder, void ClientListKernel::BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &fbb,
const schema::ResponseCode retcode, const string &reason, const schema::ResponseCode retcode, const string &reason,
std::vector<std::string> clients, const string &next_req_time, std::vector<std::string> clients, const string &next_req_time,
const int iteration) { const size_t iteration) {
auto rsp_reason = client_list_resp_builder->CreateString(reason); auto rsp_reason = fbb->CreateString(reason);
auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time); auto rsp_next_req_time = fbb->CreateString(next_req_time);
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector; std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
for (auto client : clients) { for (auto client : clients) {
auto client_fb = client_list_resp_builder->CreateString(client); auto client_fb = fbb->CreateString(client);
clients_vector.push_back(client_fb); clients_vector.push_back(client_fb);
MS_LOG(WARNING) << "update client list: "; MS_LOG(WARNING) << "update client list: ";
MS_LOG(WARNING) << client; MS_LOG(WARNING) << client;
} }
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector); auto clients_fb = fbb->CreateVector(clients_vector);
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get())); schema::ReturnClientListBuilder rsp_builder(*(fbb.get()));
rsp_builder.add_retcode(retcode); rsp_builder.add_retcode(SizeToInt(retcode));
rsp_builder.add_reason(rsp_reason); rsp_builder.add_reason(rsp_reason);
rsp_builder.add_clients(clients_fb); rsp_builder.add_clients(clients_fb);
rsp_builder.add_iteration(iteration); rsp_builder.add_iteration(SizeToInt(iteration));
rsp_builder.add_next_req_time(rsp_next_req_time); rsp_builder.add_next_req_time(rsp_next_req_time);
auto rsp_exchange_keys = rsp_builder.Finish(); auto rsp_exchange_keys = rsp_builder.Finish();
client_list_resp_builder->Finish(rsp_exchange_keys); fbb->Finish(rsp_exchange_keys);
return; return;
} }

View File

@ -29,6 +29,9 @@ namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// results of signature verification
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
class ClientListKernel : public RoundKernel { class ClientListKernel : public RoundKernel {
public: public:
ClientListKernel() = default; ClientListKernel() = default;
@ -37,12 +40,13 @@ class ClientListKernel : public RoundKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
bool Reset() override; bool Reset() override;
void BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &client_list_resp_builder, void BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const schema::ResponseCode retcode, const string &reason, std::vector<std::string> clients, const string &reason, std::vector<std::string> clients, const string &next_req_time,
const string &next_req_time, const int iteration); const size_t iteration);
private: private:
armour::CipherInit *cipher_init_; armour::CipherInit *cipher_init_;
sigVerifyResult VerifySignature(const schema::GetClientList *get_clients_req);
bool DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req, bool DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
const std::shared_ptr<server::FBBuilder> &fbb); const std::shared_ptr<server::FBBuilder> &fbb);
Executor *executor_; Executor *executor_;

View File

@ -18,6 +18,7 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <memory> #include <memory>
#include <map>
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
@ -27,24 +28,15 @@ void ExchangeKeysKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
} }
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_key_ = &armour::CipherKeys::GetInstance(); cipher_key_ = &armour::CipherKeys::GetInstance();
} }
bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num) { bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const size_t iter_num) {
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
std::string reason = "Current amount for exchangeKey is enough. Please retry later."; std::string reason = "Current amount for exchangeKey is enough. Please retry later.";
cipher_key_->BuildExchangeKeysRsp( cipher_key_->BuildExchangeKeysRsp(
fbb, schema::ResponseCode_OutOfTime, reason, fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
IntToSize(iter_num));
MS_LOG(WARNING) << reason; MS_LOG(WARNING) << reason;
return true; return true;
} }
@ -53,30 +45,84 @@ bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBB
bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestExchangeKeys *exchange_keys_req, const schema::RequestExchangeKeys *exchange_keys_req,
const int iter_num) { const size_t iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(exchange_keys_req, false); MS_ERROR_IF_NULL_W_RET_VAL(exchange_keys_req, false);
if (!DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str())) { if (!DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str())) {
std::string reason = "Counting for exchange kernel request failed. Please retry later."; std::string reason = "Counting for exchange kernel request failed. Please retry later.";
cipher_key_->BuildExchangeKeysRsp( cipher_key_->BuildExchangeKeysRsp(
fbb, schema::ResponseCode_OutOfTime, reason, fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
IntToSize(iter_num));
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
return true; return true;
} }
sigVerifyResult ExchangeKeysKernel::VerifySignature(const schema::RequestExchangeKeys *exchange_keys_req) {
std::string fl_id = exchange_keys_req->fl_id()->str();
std::string timestamp = exchange_keys_req->timestamp()->str();
int iteration = exchange_keys_req->iteration();
std::string iter_str = std::to_string(iteration);
auto fbs_signature = exchange_keys_req->signature();
std::vector<unsigned char> signature;
if (fbs_signature == nullptr) {
MS_LOG(ERROR) << "signature in exchange_keys_req is nullptr";
return sigVerifyResult::FAILED;
}
signature.assign(fbs_signature->begin(), fbs_signature->end());
std::map<std::string, std::string> key_attestations;
const fl::PBMetadata &key_attestations_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
auto iter = key_attestation_pb.key_attestations().begin();
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
}
if (key_attestations.find(fl_id) == key_attestations.end()) {
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
return sigVerifyResult::FAILED;
}
auto fbs_cpk = exchange_keys_req->c_pk();
auto fbs_spk = exchange_keys_req->s_pk();
if (fbs_cpk == nullptr || fbs_spk == nullptr) {
MS_LOG(ERROR) << "public key from exchange_keys_req is null";
return sigVerifyResult::FAILED;
}
size_t spk_len = fbs_spk->size();
size_t cpk_len = fbs_cpk->size();
std::vector<uint8_t> cpk(cpk_len);
std::vector<uint8_t> spk(spk_len);
bool ret_create_code_cpk = mindspore::armour::CreateArray<uint8_t>(&cpk, *fbs_cpk);
bool ret_create_code_spk = mindspore::armour::CreateArray<uint8_t>(&spk, *fbs_spk);
if (!(ret_create_code_cpk && ret_create_code_spk)) {
MS_LOG(ERROR) << "create array for public keys failed";
return sigVerifyResult::FAILED;
}
std::vector<unsigned char> src_data;
(void)src_data.insert(src_data.end(), cpk.begin(), cpk.end());
(void)src_data.insert(src_data.end(), spk.begin(), spk.end());
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
mindspore::ps::server::CertVerify certVerify;
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
return sigVerifyResult::FAILED;
}
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
return sigVerifyResult::TIMEOUT;
}
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
return sigVerifyResult::PASSED;
}
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching ExchangeKey kernel.";
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); MS_LOG(INFO) << "Launching ExchangeKey kernel, ITERATION NUMBER IS : " << iter_num;
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is " bool response = false;
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1 || outputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid."; std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
@ -91,11 +137,56 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
return false; return false;
} }
if (ReachThresholdForExchangeKeys(fbb, SizeToInt(iter_num))) { if (ReachThresholdForExchangeKeys(fbb, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::RequestExchangeKeys>()) {
std::string reason = "The schema of RequestExchangeKeys is invalid.";
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data); const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
if (exchange_keys_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for ExchangeKeys.";
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(exchange_keys_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::PASSED) {
MS_LOG(INFO) << "verify signature passed!";
}
}
size_t iter_client = IntToSize(exchange_keys_req->iteration()); size_t iter_client = IntToSize(exchange_keys_req->iteration());
if (iter_num != iter_client) { if (iter_num != iter_client) {
MS_LOG(ERROR) << "ExchangeKeys iteration number is invalid: server now iteration is " << iter_num MS_LOG(ERROR) << "ExchangeKeys iteration number is invalid: server now iteration is " << iter_num
@ -107,19 +198,16 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
} }
response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb); response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
if (!response) { if (!response) {
MS_LOG(WARNING) << "update exchange keys is failed."; MS_LOG(ERROR) << "update exchange keys is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
if (!CountForExchangeKeys(fbb, exchange_keys_req, SizeToInt(iter_num))) { if (!CountForExchangeKeys(fbb, exchange_keys_req, iter_num)) {
MS_LOG(ERROR) << "count for exchange keys failed."; MS_LOG(ERROR) << "count for exchange keys failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
return true; return true;
} }

View File

@ -25,11 +25,15 @@
#include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/kernel/round/round_kernel_factory.h"
#include "fl/server/executor.h" #include "fl/server/executor.h"
#include "fl/armour/cipher/cipher_keys.h" #include "fl/armour/cipher/cipher_keys.h"
#include "fl/armour/cipher/cipher_meta_storage.h"
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// results of signature verification
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
class ExchangeKeysKernel : public RoundKernel { class ExchangeKeysKernel : public RoundKernel {
public: public:
ExchangeKeysKernel() = default; ExchangeKeysKernel() = default;
@ -43,9 +47,10 @@ class ExchangeKeysKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
size_t iteration_time_window_; size_t iteration_time_window_;
armour::CipherKeys *cipher_key_; armour::CipherKeys *cipher_key_;
bool ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num); sigVerifyResult VerifySignature(const schema::RequestExchangeKeys *exchange_keys_req);
bool ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const size_t iter_num);
bool CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestExchangeKeys *exchange_keys_req, bool CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestExchangeKeys *exchange_keys_req,
const int iter_num); const size_t iter_num);
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -17,6 +17,8 @@
#include "fl/server/kernel/round/get_keys_kernel.h" #include "fl/server/kernel/round/get_keys_kernel.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include <utility>
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
@ -26,24 +28,16 @@ void GetKeysKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
} }
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_key_ = &armour::CipherKeys::GetInstance(); cipher_key_ = &armour::CipherKeys::GetInstance();
} }
bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req, bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
const int iter_num) { const size_t iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(get_keys_req, false); MS_ERROR_IF_NULL_W_RET_VAL(get_keys_req, false);
if (!DistributedCountService::GetInstance().Count(name_, get_keys_req->fl_id()->str())) { if (!DistributedCountService::GetInstance().Count(name_, get_keys_req->fl_id()->str())) {
std::string reason = "Counting for getkeys kernel request failed. Please retry later."; std::string reason = "Counting for getkeys kernel request failed. Please retry later.";
cipher_key_->BuildGetKeysRsp( cipher_key_->BuildGetKeysRsp(
fbb, schema::ResponseCode_OutOfTime, IntToSize(iter_num), fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), false); std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), false);
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
@ -51,16 +45,52 @@ bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const
return true; return true;
} }
sigVerifyResult GetKeysKernel::VerifySignature(const schema::GetExchangeKeys *get_keys_req) {
std::string fl_id = get_keys_req->fl_id()->str();
std::string timestamp = get_keys_req->timestamp()->str();
int iteration = get_keys_req->iteration();
std::string iter_str = std::to_string(iteration);
auto fbs_signature = get_keys_req->signature();
std::vector<unsigned char> signature;
if (fbs_signature == nullptr) {
MS_LOG(ERROR) << "signature in get_keys_req is nullptr";
return sigVerifyResult::FAILED;
}
signature.assign(fbs_signature->begin(), fbs_signature->end());
std::map<std::string, std::string> key_attestations;
const fl::PBMetadata &key_attestations_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
auto iter = key_attestation_pb.key_attestations().begin();
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
}
if (key_attestations.find(fl_id) == key_attestations.end()) {
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
return sigVerifyResult::FAILED;
}
std::vector<unsigned char> src_data;
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
mindspore::ps::server::CertVerify certVerify;
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
return sigVerifyResult::FAILED;
}
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
return sigVerifyResult::TIMEOUT;
}
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
return sigVerifyResult::PASSED;
}
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching GetKeys kernel.";
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); MS_LOG(INFO) << "Launching GetKeys kernel, ITERATION NUMBER IS : " << iter_num;
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetKeysKernel allowed Duration Is " bool response = false;
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1 || outputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid."; std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
@ -74,12 +104,54 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough."; MS_LOG(WARNING) << "Current amount for GetKeysKernel is enough.";
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::GetExchangeKeys>()) {
std::string reason = "The schema of GetExchangeKeys is invalid.";
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
if (get_exchange_keys_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for GetExchangeKeys.";
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(get_exchange_keys_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), false);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::PASSED) {
MS_LOG(INFO) << "verify signature passed!";
}
} }
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
size_t iter_client = IntToSize(get_exchange_keys_req->iteration()); size_t iter_client = IntToSize(get_exchange_keys_req->iteration());
if (iter_num != iter_client) { if (iter_num != iter_client) {
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
@ -91,18 +163,15 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
} }
response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb); response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
if (!response) { if (!response) {
MS_LOG(WARNING) << "get public keys is failed."; MS_LOG(WARNING) << "get public keys not ready.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
if (!CountForGetKeys(fbb, get_exchange_keys_req, SizeToInt(iter_num))) { if (!CountForGetKeys(fbb, get_exchange_keys_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
return true; return true;
} }

View File

@ -30,6 +30,9 @@ namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// results of signature verification
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
class GetKeysKernel : public RoundKernel { class GetKeysKernel : public RoundKernel {
public: public:
GetKeysKernel() = default; GetKeysKernel() = default;
@ -43,8 +46,9 @@ class GetKeysKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
size_t iteration_time_window_; size_t iteration_time_window_;
armour::CipherKeys *cipher_key_; armour::CipherKeys *cipher_key_;
sigVerifyResult VerifySignature(const schema::GetExchangeKeys *get_keys_req);
bool CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req, bool CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
const int iter_num); const size_t iter_num);
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -0,0 +1,281 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fl/server/kernel/round/get_list_sign_kernel.h"
#include <utility>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "schema/cipher_generated.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
void GetListSignKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
cipher_init_ = &armour::CipherInit::GetInstance();
}
sigVerifyResult GetListSignKernel::VerifySignature(const schema::RequestAllClientListSign *client_list_sign_req) {
std::string fl_id = client_list_sign_req->fl_id()->str();
std::string timestamp = client_list_sign_req->timestamp()->str();
int iteration = client_list_sign_req->iteration();
std::string iter_str = std::to_string(iteration);
auto fbs_signature = client_list_sign_req->signature();
std::vector<unsigned char> signature;
if (fbs_signature == nullptr) {
MS_LOG(ERROR) << "signature in client_list_sign_req is nullptr";
return sigVerifyResult::FAILED;
}
signature.assign(fbs_signature->begin(), fbs_signature->end());
std::map<std::string, std::string> key_attestations;
const fl::PBMetadata &key_attestations_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
auto iter = key_attestation_pb.key_attestations().begin();
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
}
if (key_attestations.find(fl_id) == key_attestations.end()) {
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
return sigVerifyResult::FAILED;
}
std::vector<unsigned char> src_data;
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
mindspore::ps::server::CertVerify certVerify;
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
return sigVerifyResult::FAILED;
}
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
return sigVerifyResult::TIMEOUT;
}
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
return sigVerifyResult::PASSED;
}
bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Launching GetListSign kernel, Iteration number is " << iter_num;
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
std::map<std::string, std::vector<unsigned char>> list_signs;
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::RequestAllClientListSign>()) {
std::string reason = "The schema of RequestAllClientListSign is invalid.";
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestAllClientListSign *get_list_sign_req =
flatbuffers::GetRoot<schema::RequestAllClientListSign>(req_data);
if (get_list_sign_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestAllClientListSign.";
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(get_list_sign_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
iter_num, list_signs);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::PASSED) {
MS_LOG(INFO) << "verify signature passed!";
}
}
size_t iter_client = IntToSize(get_list_sign_req->iteration());
if (iter_num != iter_client) {
MS_LOG(ERROR) << "get list sign iteration number is invalid: server now iteration is " << iter_num
<< ". client request iteration is " << iter_client;
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::string fl_id = get_list_sign_req->fl_id()->str();
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(WARNING) << "Current amount for GetListSignKernel is enough.";
}
if (!GetListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_list_sign_req, fbb)) {
MS_LOG(WARNING) << "get list signs not ready.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) {
std::string reason = "Counting for get list sign request failed. Please retry later. " + count_reason;
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
iter_num, list_signs);
MS_LOG(ERROR) << reason;
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
bool GetListSignKernel::GetListSign(const size_t cur_iterator, const std::string &next_req_time,
const schema::RequestAllClientListSign *get_list_sign_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
MS_LOG(INFO) << "CipherMgr::SendClientListSign START";
std::map<std::string, std::vector<unsigned char>> client_list_signs_empty;
std::map<std::string, std::vector<unsigned char>> client_list_signs_all;
const fl::PBMetadata &clients_sign_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientListSigns);
const fl::ClientListSign &clients_sign_pb = clients_sign_pb_out.client_list_sign();
size_t cur_clients_sign_num = IntToSize(clients_sign_pb.client_list_sign_size());
if (cur_clients_sign_num < cipher_init_->push_list_sign_threshold) {
MS_LOG(INFO) << "The server is not ready. push_list_sign_needed: " << cipher_init_->push_list_sign_threshold;
MS_LOG(INFO) << "now push_sign_client_num: " << clients_sign_pb.client_list_sign_size();
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", next_req_time,
cur_iterator, client_list_signs_empty);
return false;
}
std::vector<string> update_model_clients;
const PBMetadata update_model_clients_pb_out =
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) {
update_model_clients.push_back(update_model_clients_pb.fl_id(SizeToInt(i)));
}
auto iter = clients_sign_pb.client_list_sign().begin();
for (; iter != clients_sign_pb.client_list_sign().end(); ++iter) {
std::vector<unsigned char> signature(iter->second.begin(), iter->second.end());
(void)client_list_signs_all.emplace(std::pair<std::string, std::vector<unsigned char>>(iter->first, signature));
}
std::string fl_id = get_list_sign_req->fl_id()->str();
if (client_list_signs_all.find(fl_id) == client_list_signs_all.end()) {
std::string reason;
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
reason = "client not send list signature, but in update model client list.";
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator,
client_list_signs_all);
} else {
reason = "client not send list signature, && client is illegal";
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator,
client_list_signs_empty);
}
MS_LOG(WARNING) << reason;
return false;
}
if (client_list_signs_all.find(fl_id) != client_list_signs_all.end()) {
// the client has sended signature, return false.
std::string reason = "The server has received the request, please do not request again.";
MS_LOG(WARNING) << reason;
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator,
client_list_signs_all);
return false;
}
std::string reason = "send update model client list signature success. ";
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator,
client_list_signs_all);
MS_LOG(INFO) << "CipherMgr::Send Client ListSign Success";
return true;
}
bool GetListSignKernel::Reset() {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Get List Signature kernel reset!";
DistributedCountService::GetInstance().ResetCounter(name_);
StopTimer();
return true;
}
void GetListSignKernel::BuildGetListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb,
const schema::ResponseCode retcode, const string &reason,
const string &next_req_time, const size_t iteration,
const std::map<std::string, std::vector<unsigned char>> &list_signs) {
auto rsp_reason = fbb->CreateString(reason);
auto rsp_next_req_time = fbb->CreateString(next_req_time);
if (list_signs.size() == 0) {
schema::ReturnAllClientListSignBuilder rsp_builder(*(fbb.get()));
rsp_builder.add_retcode(static_cast<int>(retcode));
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_next_req_time(rsp_next_req_time);
rsp_builder.add_iteration(SizeToInt(iteration));
auto rsp_get_list_sign = rsp_builder.Finish();
fbb->Finish(rsp_get_list_sign);
return;
}
std::vector<flatbuffers::Offset<schema::ClientListSign>> client_list_signs;
for (auto iter = list_signs.begin(); iter != list_signs.end(); ++iter) {
auto fbs_fl_id = fbb->CreateString(iter->first);
auto fbs_sign = fbb->CreateVector(iter->second.data(), iter->second.size());
auto cur_sign = schema::CreateClientListSign(*fbb, fbs_fl_id, fbs_sign);
client_list_signs.push_back(cur_sign);
}
auto all_signs = fbb->CreateVector(client_list_signs);
schema::ReturnAllClientListSignBuilder rsp_builder(*(fbb.get()));
rsp_builder.add_retcode(static_cast<int>(retcode));
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_next_req_time(rsp_next_req_time);
rsp_builder.add_iteration(SizeToInt(iteration));
rsp_builder.add_client_list_sign(all_signs);
auto rsp_get_list_sign = rsp_builder.Finish();
fbb->Finish(rsp_get_list_sign);
return;
}
REG_ROUND_KERNEL(getListSign, GetListSignKernel)
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_LIST_SIGN_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_LIST_SIGN_KERNEL_H
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "fl/server/common.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
#include "fl/armour/cipher/cipher_init.h"
#include "fl/server/executor.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
// results of signature verification
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
class GetListSignKernel : public RoundKernel {
public:
GetListSignKernel() = default;
~GetListSignKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Reset() override;
void BuildGetListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const string &reason, const string &next_req_time, const size_t iteration,
const std::map<std::string, std::vector<unsigned char>> &list_signs);
private:
armour::CipherInit *cipher_init_;
Executor *executor_;
size_t iteration_time_window_;
sigVerifyResult VerifySignature(const schema::RequestAllClientListSign *client_list_sign_req);
bool GetListSign(const size_t cur_iterator, const std::string &next_req_time,
const schema::RequestAllClientListSign *client_list_sign_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb);
};
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_LIST_SIGN_KERNEL_H

View File

@ -67,9 +67,9 @@ bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
return true; return true;
} }
(void)++retry_count_; ++retry_count_;
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) { if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
MS_LOG(INFO) << "Launching GetModelKernel retry count is " << retry_count_.load(); MS_LOG(INFO) << "Launching GetModelKernel kernel. Retry count is " << retry_count_.load();
} }
const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data); const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
@ -95,14 +95,14 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp); auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
std::map<std::string, AddressPtr> feature_maps; std::map<std::string, AddressPtr> feature_maps;
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t get_model_iter = IntToSize(get_model_req->iteration()); size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model(); const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
size_t latest_iter_num = iter_to_model.rbegin()->first; size_t latest_iter_num = iter_to_model.rbegin()->first;
// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later. // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
if ((current_iter == get_model_iter && latest_iter_num != current_iter)) { if (current_iter == get_model_iter && latest_iter_num != current_iter) {
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) + std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +
". Maybe this is because\n" + "1.Client doesn't send enough update model requests.\n" + ". Maybe this is because\n" + "1. Client doesn't not send enough update model request.\n" +
"2. Worker has not push all the weights to servers."; "2. Worker has not push weights to server.";
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps, BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
std::to_string(next_req_time)); std::to_string(next_req_time));
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) { if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
@ -120,7 +120,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter); feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
} }
MS_LOG(INFO) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() MS_LOG(INFO) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter; << ", next request time is " << next_req_time << ", current iteration is " << current_iter;
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter), BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
current_iter, feature_maps, std::to_string(next_req_time)); current_iter, feature_maps, std::to_string(next_req_time));

View File

@ -51,7 +51,7 @@ class GetModelKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
// The time window of one iteration. // The time window of one iteration.
size_t iteration_time_window_; size_t iteration_time_window_{0};
// The count of retrying because the iteration is not finished. // The count of retrying because the iteration is not finished.
std::atomic<uint64_t> retry_count_; std::atomic<uint64_t> retry_count_;

View File

@ -30,23 +30,15 @@ void GetSecretsKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
} }
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_share_ = &armour::CipherShares::GetInstance(); cipher_share_ = &armour::CipherShares::GetInstance();
} }
bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
const schema::GetShareSecrets *get_secrets_req, const int iter_num) { const schema::GetShareSecrets *get_secrets_req, const size_t iter_num) {
MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false); MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false);
if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) { if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) {
std::string reason = "Counting for get secrets kernel request failed. Please retry later."; std::string reason = "Counting for get secrets kernel request failed. Please retry later.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, IntToSize(iter_num), cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
std::to_string(CURRENT_TIME_MILLI.count()), nullptr); std::to_string(CURRENT_TIME_MILLI.count()), nullptr);
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
@ -54,16 +46,52 @@ bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
return true; return true;
} }
sigVerifyResult GetSecretsKernel::VerifySignature(const schema::GetShareSecrets *get_secrets_req) {
std::string fl_id = get_secrets_req->fl_id()->str();
std::string timestamp = get_secrets_req->timestamp()->str();
int iteration = get_secrets_req->iteration();
std::string iter_str = std::to_string(iteration);
auto fbs_signature = get_secrets_req->signature();
std::vector<unsigned char> signature;
if (fbs_signature == nullptr) {
MS_LOG(ERROR) << "signature in get_secrets_req is nullptr";
return sigVerifyResult::FAILED;
}
signature.assign(fbs_signature->begin(), fbs_signature->end());
std::map<std::string, std::string> key_attestations;
const fl::PBMetadata &key_attestations_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
auto iter = key_attestation_pb.key_attestations().begin();
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
}
if (key_attestations.find(fl_id) == key_attestations.end()) {
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
return sigVerifyResult::FAILED;
}
std::vector<unsigned char> src_data;
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
mindspore::ps::server::CertVerify certVerify;
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
return sigVerifyResult::FAILED;
}
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
return sigVerifyResult::TIMEOUT;
}
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
return sigVerifyResult::PASSED;
}
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count()); std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); MS_LOG(INFO) << "Launching get secrets kernel, ITERATION NUMBER IS : " << iter_num;
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1 || outputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid."; std::string reason = "inputs or outputs size is invalid.";
@ -78,8 +106,46 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::GetShareSecrets>()) {
std::string reason = "The schema of GetShareSecrets is invalid.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data); const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
if (get_secrets_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for GetExchangeKeys.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(get_secrets_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::PASSED) {
MS_LOG(INFO) << "verify signature passed!";
}
}
size_t iter_client = IntToSize(get_secrets_req->iteration()); size_t iter_client = IntToSize(get_secrets_req->iteration());
if (iter_num != iter_client) { if (iter_num != iter_client) {
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
@ -93,20 +159,17 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough."; MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
} }
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp); bool response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
if (!response) { if (!response) {
MS_LOG(WARNING) << "get secret shares is failed."; MS_LOG(WARNING) << "get secret shares not ready.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
if (!CountForGetSecrets(fbb, get_secrets_req, SizeToInt(iter_num))) { if (!CountForGetSecrets(fbb, get_secrets_req, iter_num)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
return true; return true;
} }

View File

@ -29,6 +29,9 @@ namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// results of signature verification
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
class GetSecretsKernel : public RoundKernel { class GetSecretsKernel : public RoundKernel {
public: public:
GetSecretsKernel() = default; GetSecretsKernel() = default;
@ -42,8 +45,9 @@ class GetSecretsKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
size_t iteration_time_window_; size_t iteration_time_window_;
armour::CipherShares *cipher_share_; armour::CipherShares *cipher_share_;
sigVerifyResult VerifySignature(const schema::GetShareSecrets *get_secrets_req);
bool CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::GetShareSecrets *get_secrets_req, bool CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::GetShareSecrets *get_secrets_req,
const int iter_num); const size_t iter_num);
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -37,6 +37,13 @@ void PullWeightKernel::InitKernel(size_t) {
bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
MS_LOG(DEBUG) << "Launching PullWeightKernel kernel."; MS_LOG(DEBUG) << "Launching PullWeightKernel kernel.";
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
void *req_data = inputs[0]->addr; void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>(); std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) { if (fbb == nullptr || req_data == nullptr) {
@ -71,7 +78,7 @@ void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
} }
std::map<std::string, AddressPtr> feature_maps = {}; std::map<std::string, AddressPtr> feature_maps = {};
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
size_t pull_weight_iter = IntToSize(pull_weight_req->iteration()); size_t pull_weight_iter = static_cast<size_t>(pull_weight_req->iteration());
// The iteration from worker should be the same as server's, otherwise return SucNotReady so that worker could retry. // The iteration from worker should be the same as server's, otherwise return SucNotReady so that worker could retry.
if (pull_weight_iter != current_iter) { if (pull_weight_iter != current_iter) {
std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) + std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) +
@ -91,7 +98,7 @@ void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
weight_names.push_back(weights_names_fbs->Get(i)->str()); weight_names.push_back(weights_names_fbs->Get(i)->str());
} }
if (!executor_->IsWeightAggrDone(weight_names) || !executor_->unmasked()) { if (!executor_->IsWeightAggrDone(weight_names) || !executor_->unmasked()) {
(void)++retry_count_; ++retry_count_;
std::string reason = "The aggregation for the weights is not done yet."; std::string reason = "The aggregation for the weights is not done yet.";
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps); BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps);
if (retry_count_.load() % kPrintPullWeightForEveryRetryTime == 1) { if (retry_count_.load() % kPrintPullWeightForEveryRetryTime == 1) {
@ -134,7 +141,7 @@ void PullWeightKernel::BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb,
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get())); schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get()));
rsp_pull_weight_builder.add_retcode(static_cast<int>(retcode)); rsp_pull_weight_builder.add_retcode(SizeToInt(retcode));
rsp_pull_weight_builder.add_reason(fbs_reason); rsp_pull_weight_builder.add_reason(fbs_reason);
rsp_pull_weight_builder.add_iteration(SizeToInt(iteration)); rsp_pull_weight_builder.add_iteration(SizeToInt(iteration));
rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector); rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector);

View File

@ -0,0 +1,277 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fl/server/kernel/round/push_list_sign_kernel.h"
#include <utility>
#include <string>
#include <vector>
#include <memory>
#include <map>
#include "schema/cipher_generated.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
void PushListSignKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
}
cipher_init_ = &armour::CipherInit::GetInstance();
}
bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Launching PushListSignKernel, Iteration number is " << iter_num;
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
return false;
}
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) {
std::string reason = "FBBuilder builder or req_data is nullptr.";
MS_LOG(ERROR) << reason;
return false;
}
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::SendClientListSign>()) {
std::string reason = "The schema of PushClientListSign is invalid.";
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::SendClientListSign *client_list_sign_req = flatbuffers::GetRoot<schema::SendClientListSign>(req_data);
if (client_list_sign_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for PushClientListSign.";
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(client_list_sign_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::PASSED) {
MS_LOG(INFO) << "verify signature passed!";
}
}
return LaunchForPushListSign(client_list_sign_req, iter_num, fbb, outputs);
}
bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req,
const size_t &iter_num, const std::shared_ptr<server::FBBuilder> &fbb,
const std::vector<AddressPtr> &outputs) {
size_t iter_client = IntToSize(client_list_sign_req->iteration());
if (iter_num != iter_client) {
std::string reason = "push list sign iteration number is invalid";
MS_LOG(WARNING) << reason;
MS_LOG(WARNING) << "server now iteration is " << iter_num << ". client request iteration is " << iter_client;
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
iter_num);
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::vector<string> update_model_clients;
const PBMetadata update_model_clients_pb_out =
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) {
update_model_clients.push_back(update_model_clients_pb.fl_id(i));
}
std::string fl_id = client_list_sign_req->fl_id()->str();
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for PushListSignKernel is enough.";
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
// client in get update model client list.
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, "Current amount for PushListSignKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
} else {
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime,
"Current amount for PushListSignKernel is enough.",
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!PushListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), client_list_sign_req, fbb,
update_model_clients)) {
MS_LOG(ERROR) << "push client list sign failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
std::string count_reason = "";
if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) {
std::string reason = "Counting for push list sign request failed. Please retry later. " + count_reason;
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
iter_num);
MS_LOG(ERROR) << reason;
return true;
}
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
sigVerifyResult PushListSignKernel::VerifySignature(const schema::SendClientListSign *client_list_sign_req) {
std::string fl_id = client_list_sign_req->fl_id()->str();
std::string timestamp = client_list_sign_req->timestamp()->str();
int iteration = client_list_sign_req->iteration();
std::string iter_str = std::to_string(iteration);
auto fbs_signature = client_list_sign_req->req_signature();
std::vector<unsigned char> signature;
if (fbs_signature == nullptr) {
MS_LOG(ERROR) << "signature in client_list_sign_req is nullptr";
return sigVerifyResult::FAILED;
}
signature.assign(fbs_signature->begin(), fbs_signature->end());
std::map<std::string, std::string> key_attestations;
const fl::PBMetadata &key_attestations_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
auto iter = key_attestation_pb.key_attestations().begin();
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
}
if (key_attestations.find(fl_id) == key_attestations.end()) {
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
return sigVerifyResult::FAILED;
}
std::vector<unsigned char> src_data;
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
mindspore::ps::server::CertVerify certVerify;
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
return sigVerifyResult::FAILED;
}
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
return sigVerifyResult::TIMEOUT;
}
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
return sigVerifyResult::PASSED;
}
bool PushListSignKernel::PushListSign(const size_t cur_iterator, const std::string &next_req_time,
const schema::SendClientListSign *client_list_sign_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb,
const std::vector<std::string> &update_model_clients) {
MS_LOG(INFO) << "CipherMgr::PushClientListSign START";
std::vector<std::string> get_client_list; // the clients which get update model client list
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetUpdateModelClientList,
&get_client_list);
std::string fl_id = client_list_sign_req->fl_id()->str();
if (find(get_client_list.begin(), get_client_list.end(), fl_id) == get_client_list.end()) {
// client not in get update model client list.
std::string reason = "client send signature is not in get update model client list. && client is illegal";
MS_LOG(WARNING) << reason;
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
// client in update model client list, client can move to next round
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator);
} else {
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator);
}
return false;
}
std::vector<std::string> send_signs_clients;
const fl::PBMetadata &clients_sign_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientListSigns);
const fl::ClientListSign &clients_sign_pb = clients_sign_pb_out.client_list_sign();
auto iter = clients_sign_pb.client_list_sign().begin();
for (; iter != clients_sign_pb.client_list_sign().end(); ++iter) {
send_signs_clients.push_back(iter->first);
}
if (find(send_signs_clients.begin(), send_signs_clients.end(), fl_id) != send_signs_clients.end()) {
// the client has sended signature, return false.
std::string reason = "The server has received the request, please do not request again.";
MS_LOG(ERROR) << reason;
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator);
return false;
}
auto fbs_signature = client_list_sign_req->signature();
std::vector<char> signature;
if (fbs_signature != nullptr) {
signature.assign(fbs_signature->begin(), fbs_signature->end());
}
fl::PairClientListSign pair_client_list_sign_pb;
pair_client_list_sign_pb.set_fl_id(fl_id);
pair_client_list_sign_pb.set_signature(signature.data(), signature.size());
fl::PBMetadata pb_data;
pb_data.mutable_pair_client_list_sign()->MergeFrom(pair_client_list_sign_pb);
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxClientListSigns, pb_data);
if (!retcode) {
std::string reason = "store client list signature failed";
MS_LOG(ERROR) << reason;
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator);
return false;
}
std::string reason = "send update model client list signature success. ";
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator);
MS_LOG(INFO) << "CipherMgr::PushClientListSign Success";
return true;
}
bool PushListSignKernel::Reset() {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
MS_LOG(INFO) << "Push list sign kernel reset!";
DistributedCountService::GetInstance().ResetCounter(name_);
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxClientListSigns);
StopTimer();
return true;
}
void PushListSignKernel::BuildPushListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb,
const schema::ResponseCode retcode, const string &reason,
const string &next_req_time, const size_t iteration) {
auto rsp_reason = fbb->CreateString(reason);
auto rsp_next_req_time = fbb->CreateString(next_req_time);
schema::ResponseClientListSignBuilder rsp_builder(*(fbb.get()));
rsp_builder.add_retcode(static_cast<int>(retcode));
rsp_builder.add_reason(rsp_reason);
rsp_builder.add_next_req_time(rsp_next_req_time);
rsp_builder.add_iteration(SizeToInt(iteration));
auto rsp_push_list_sign = rsp_builder.Finish();
fbb->Finish(rsp_push_list_sign);
return;
}
REG_ROUND_KERNEL(pushListSign, PushListSignKernel)
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_LIST_SIGN_KERNEL_H
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_LIST_SIGN_KERNEL_H
#include <vector>
#include <string>
#include <memory>
#include "fl/server/common.h"
#include "fl/server/kernel/round/round_kernel.h"
#include "fl/server/kernel/round/round_kernel_factory.h"
#include "fl/armour/cipher/cipher_init.h"
#include "fl/server/executor.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
// results of signature verification
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
class PushListSignKernel : public RoundKernel {
public:
PushListSignKernel() = default;
~PushListSignKernel() override = default;
void InitKernel(size_t required_cnt) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req, const size_t &iter_num,
const std::shared_ptr<server::FBBuilder> &fbb, const std::vector<AddressPtr> &outputs);
bool Reset() override;
void BuildPushListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
const string &reason, const string &next_req_time, const size_t iteration);
private:
armour::CipherInit *cipher_init_;
Executor *executor_;
size_t iteration_time_window_;
sigVerifyResult VerifySignature(const schema::SendClientListSign *client_list_sign_req);
bool PushListSign(const size_t cur_iterator, const std::string &next_req_time,
const schema::SendClientListSign *client_list_sign_req,
const std::shared_ptr<fl::server::FBBuilder> &fbb,
const std::vector<std::string> &update_model_clients);
};
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_LIST_SIGN_KERNEL_H

View File

@ -98,7 +98,7 @@ ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb,
void PushMetricsKernel::BuildPushMetricsRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode) { void PushMetricsKernel::BuildPushMetricsRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode) {
MS_ERROR_IF_NULL_WO_RET_VAL(fbb); MS_ERROR_IF_NULL_WO_RET_VAL(fbb);
schema::ResponsePushMetricsBuilder rsp_push_metrics_builder(*(fbb.get())); schema::ResponsePushMetricsBuilder rsp_push_metrics_builder(*(fbb.get()));
rsp_push_metrics_builder.add_retcode(static_cast<int>(retcode)); rsp_push_metrics_builder.add_retcode(retcode);
auto rsp_push_metrics = rsp_push_metrics_builder.Finish(); auto rsp_push_metrics = rsp_push_metrics_builder.Finish();
fbb->Finish(rsp_push_metrics); fbb->Finish(rsp_push_metrics);
} }

View File

@ -33,6 +33,13 @@ void PushWeightKernel::InitKernel(size_t) {
bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
MS_LOG(INFO) << "Launching PushWeightKernel kernel."; MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, reason.c_str(), reason.size());
return true;
}
void *req_data = inputs[0]->addr; void *req_data = inputs[0]->addr;
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>(); std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
if (fbb == nullptr || req_data == nullptr) { if (fbb == nullptr || req_data == nullptr) {
@ -141,7 +148,7 @@ void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb,
} }
auto fbs_reason = fbb->CreateString(reason); auto fbs_reason = fbb->CreateString(reason);
schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get())); schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
rsp_push_weight_builder.add_retcode(static_cast<int>(retcode)); rsp_push_weight_builder.add_retcode(SizeToInt(retcode));
rsp_push_weight_builder.add_reason(fbs_reason); rsp_push_weight_builder.add_reason(fbs_reason);
rsp_push_weight_builder.add_iteration(SizeToInt(iteration)); rsp_push_weight_builder.add_iteration(SizeToInt(iteration));
auto rsp_push_weight = rsp_push_weight_builder.Finish(); auto rsp_push_weight = rsp_push_weight_builder.Finish();

View File

@ -18,6 +18,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include <utility>
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
@ -27,7 +29,6 @@ void ReconstructSecretsKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
} }
auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) { auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kReconstructSeccrets) { if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kReconstructSeccrets) {
MS_LOG(INFO) << "start FinishIteration"; MS_LOG(INFO) << "start FinishIteration";
@ -44,14 +45,52 @@ void ReconstructSecretsKernel::InitKernel(size_t) {
{first_cnt_handler, last_cnt_handler}); {first_cnt_handler, last_cnt_handler});
} }
sigVerifyResult ReconstructSecretsKernel::VerifySignature(const schema::SendReconstructSecret *reconstruct_secret_req) {
std::string fl_id = reconstruct_secret_req->fl_id()->str();
std::string timestamp = reconstruct_secret_req->timestamp()->str();
int iteration = reconstruct_secret_req->iteration();
std::string iter_str = std::to_string(iteration);
auto fbs_signature = reconstruct_secret_req->signature();
std::vector<unsigned char> signature;
if (fbs_signature == nullptr) {
MS_LOG(ERROR) << "signature in reconstruct_secret_req is nullptr";
return sigVerifyResult::FAILED;
}
signature.assign(fbs_signature->begin(), fbs_signature->end());
std::map<std::string, std::string> key_attestations;
const fl::PBMetadata &key_attestations_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
auto iter = key_attestation_pb.key_attestations().begin();
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
}
if (key_attestations.find(fl_id) == key_attestations.end()) {
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
return sigVerifyResult::FAILED;
}
std::vector<unsigned char> src_data;
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
mindspore::ps::server::CertVerify certVerify;
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
return sigVerifyResult::FAILED;
}
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
return sigVerifyResult::TIMEOUT;
}
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
return sigVerifyResult::PASSED;
}
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
bool response = false; bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); MS_LOG(INFO) << "Launching ReconstructSecrets Kernel, Iteration number is " << iter_num;
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ReconstructSecretsKernel total duration is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1 || outputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size(); MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
@ -73,14 +112,55 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList); DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list(); const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
for (int i = 0; i < update_model_clients_pb.fl_id_size(); ++i) { for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) {
update_model_clients.push_back(update_model_clients_pb.fl_id(i)); update_model_clients.push_back(update_model_clients_pb.fl_id(i));
} }
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::SendReconstructSecret>()) {
std::string reason = "The schema of SendReconstructSecret is invalid.";
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num),
std::to_string(CURRENT_TIME_MILLI.count()));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::SendReconstructSecret *reconstruct_secret_req = const schema::SendReconstructSecret *reconstruct_secret_req =
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data); flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
std::string fl_id = reconstruct_secret_req->fl_id()->str(); if (reconstruct_secret_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for SendReconstructSecret.";
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num),
std::to_string(CURRENT_TIME_MILLI.count()));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(reconstruct_secret_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
SizeToInt(iter_num), std::to_string(CURRENT_TIME_MILLI.count()));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason, SizeToInt(iter_num),
std::to_string(CURRENT_TIME_MILLI.count()));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::PASSED) {
MS_LOG(INFO) << "verify signature passed!";
}
}
std::string fl_id = reconstruct_secret_req->fl_id()->str();
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough."; MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) { if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
@ -106,11 +186,10 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough."; MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough.";
} }
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); MS_LOG(INFO) << "reconstruct_secrets_kernel success.";
MS_LOG(INFO) << "reconstruct_secrets_kernel success time is : " << duration;
if (!response) { if (!response) {
MS_LOG(INFO) << "reconstruct_secrets_kernel response is false."; MS_LOG(INFO) << "reconstruct_secrets_kernel response not ready.";
} }
return true; return true;
} }

View File

@ -30,6 +30,9 @@ namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// results of signature verification
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
class ReconstructSecretsKernel : public RoundKernel { class ReconstructSecretsKernel : public RoundKernel {
public: public:
ReconstructSecretsKernel() = default; ReconstructSecretsKernel() = default;
@ -43,8 +46,10 @@ class ReconstructSecretsKernel : public RoundKernel {
private: private:
std::string name_unmask_; std::string name_unmask_;
Executor *executor_;
size_t iteration_time_window_{0}; size_t iteration_time_window_{0};
armour::CipherReconStruct cipher_reconstruct_; armour::CipherReconStruct cipher_reconstruct_;
sigVerifyResult VerifySignature(const schema::SendReconstructSecret *reconstruct_secret_req);
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -124,7 +124,7 @@ void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const v
outputs[0]->size = len; outputs[0]->size = len;
std::unique_lock<std::mutex> lock(heap_data_mtx_); std::unique_lock<std::mutex> lock(heap_data_mtx_);
(void)heap_data_.emplace(outputs[0], std::move(output_data)); (void)heap_data_.insert(std::make_pair(outputs[0], std::move(output_data)));
return; return;
} }
} // namespace kernel } // namespace kernel

View File

@ -26,7 +26,7 @@
#include <utility> #include <utility>
#include <chrono> #include <chrono>
#include <thread> #include <thread>
#include "utils/hash_map.h" #include <unordered_map>
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "fl/server/common.h" #include "fl/server/common.h"
@ -120,7 +120,7 @@ class RoundKernel : virtual public CPUKernel {
std::mutex release_mtx_; std::mutex release_mtx_;
std::queue<AddressPtr> heap_data_to_release_; std::queue<AddressPtr> heap_data_to_release_;
std::mutex heap_data_mtx_; std::mutex heap_data_mtx_;
mindspore::HashMap<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_; std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -20,8 +20,9 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include "utils/hash_map.h" #include <unordered_map>
#include "fl/server/common.h" #include "fl/server/common.h"
#include "fl/server/cert_verify.h"
#include "fl/server/kernel/round/round_kernel.h" #include "fl/server/kernel/round/round_kernel.h"
namespace mindspore { namespace mindspore {
@ -42,7 +43,7 @@ class RoundKernelFactory {
RoundKernelFactory(const RoundKernelFactory &) = delete; RoundKernelFactory(const RoundKernelFactory &) = delete;
RoundKernelFactory &operator=(const RoundKernelFactory &) = delete; RoundKernelFactory &operator=(const RoundKernelFactory &) = delete;
mindspore::HashMap<std::string, RoundKernelCreator> name_to_creator_map_; std::unordered_map<std::string, RoundKernelCreator> name_to_creator_map_;
}; };
class RoundKernelRegister { class RoundKernelRegister {

View File

@ -17,6 +17,8 @@
#include "fl/server/kernel/round/share_secrets_kernel.h" #include "fl/server/kernel/round/share_secrets_kernel.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <map>
#include <utility>
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
@ -26,13 +28,6 @@ void ShareSecretsKernel::InitKernel(size_t) {
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
} }
executor_ = &Executor::GetInstance();
MS_EXCEPTION_IF_NULL(executor_);
if (!executor_->initialized()) {
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
return;
}
cipher_share_ = &armour::CipherShares::GetInstance(); cipher_share_ = &armour::CipherShares::GetInstance();
} }
@ -49,21 +44,58 @@ bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> &
return true; return true;
} }
sigVerifyResult ShareSecretsKernel::VerifySignature(const schema::RequestShareSecrets *share_secrets_req) {
std::string fl_id = share_secrets_req->fl_id()->str();
std::string timestamp = share_secrets_req->timestamp()->str();
int iteration = share_secrets_req->iteration();
std::string iter_str = std::to_string(iteration);
auto fbs_signature = share_secrets_req->signature();
std::vector<unsigned char> signature;
if (fbs_signature == nullptr) {
MS_LOG(ERROR) << "signature in share_secrets_req is nullptr";
return sigVerifyResult::FAILED;
}
signature.assign(fbs_signature->begin(), fbs_signature->end());
std::map<std::string, std::string> key_attestations;
const fl::PBMetadata &key_attestations_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
auto iter = key_attestation_pb.key_attestations().begin();
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
}
if (key_attestations.find(fl_id) == key_attestations.end()) {
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
return sigVerifyResult::FAILED;
}
std::vector<unsigned char> src_data;
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
mindspore::ps::server::CertVerify certVerify;
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
return sigVerifyResult::FAILED;
}
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
return sigVerifyResult::TIMEOUT;
}
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
return sigVerifyResult::PASSED;
}
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
bool response = false; bool response = false;
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); MS_LOG(INFO) << "Launching ShareSecretsKernel, ITERATION NUMBER IS : " << iter_num;
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is "
<< total_duration;
clock_t start_time = clock();
if (inputs.size() != 1 || outputs.size() != 1) { if (inputs.size() != 1 || outputs.size() != 1) {
std::string reason = "inputs or outputs size is invalid."; std::string reason = "inputs or outputs size is invalid.";
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>(); std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
void *req_data = inputs[0]->addr; void *req_data = inputs[0]->addr;
if (fbb == nullptr || req_data == nullptr) { if (fbb == nullptr || req_data == nullptr) {
@ -71,7 +103,6 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
MS_LOG(ERROR) << reason; MS_LOG(ERROR) << reason;
return false; return false;
} }
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough."; MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
@ -80,7 +111,50 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
if (!verifier.VerifyBuffer<schema::RequestShareSecrets>()) {
std::string reason = "The schema of RequestShareSecrets is invalid.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data); const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
if (share_secrets_req == nullptr) {
std::string reason = "Building flatbuffers schema failed for RequestShareSecrets.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(share_secrets_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::PASSED) {
MS_LOG(INFO) << "verify signature passed!";
}
}
size_t iter_client = IntToSize(share_secrets_req->iteration()); size_t iter_client = IntToSize(share_secrets_req->iteration());
if (iter_num != iter_client) { if (iter_num != iter_client) {
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
@ -93,7 +167,7 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
response = cipher_share_->ShareSecrets(SizeToInt(iter_num), share_secrets_req, fbb, response = cipher_share_->ShareSecrets(SizeToInt(iter_num), share_secrets_req, fbb,
std::to_string(CURRENT_TIME_MILLI.count())); std::to_string(CURRENT_TIME_MILLI.count()));
if (!response) { if (!response) {
MS_LOG(WARNING) << "update secret shares is failed."; MS_LOG(ERROR) << "update secret shares is failed.";
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true; return true;
} }
@ -102,9 +176,6 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
return true; return true;
} }
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
clock_t end_time = clock();
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
return true; return true;
} }

View File

@ -30,6 +30,9 @@ namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
namespace kernel { namespace kernel {
// results of signature verification
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
class ShareSecretsKernel : public RoundKernel { class ShareSecretsKernel : public RoundKernel {
public: public:
ShareSecretsKernel() = default; ShareSecretsKernel() = default;
@ -43,6 +46,7 @@ class ShareSecretsKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
size_t iteration_time_window_; size_t iteration_time_window_;
armour::CipherShares *cipher_share_; armour::CipherShares *cipher_share_;
sigVerifyResult VerifySignature(const schema::RequestShareSecrets *share_secrets_req);
bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req, bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req,
const size_t iter_num); const size_t iter_num);
}; };

View File

@ -19,11 +19,11 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "fl/server/model_store.h"
#include "fl/server/iteration.h"
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_init.h" #include "fl/armour/cipher/cipher_init.h"
#endif #endif
#include "fl/server/model_store.h"
#include "fl/server/iteration.h"
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
@ -46,6 +46,9 @@ void StartFLJobKernel::InitKernel(size_t) {
PBMetadata devices_metas; PBMetadata devices_metas;
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas); DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas);
PBMetadata client_key_attestation;
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxClientKeyAttestation, client_key_attestation);
return; return;
} }
@ -93,6 +96,17 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return true; return true;
} }
if (ps::PSContext::instance()->pki_verify()) {
if (!JudgeFLJobCert(fbb, start_fl_job_req)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (!StoreKeyAttestation(fbb, start_fl_job_req)) {
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
}
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req); DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
result_code = ReadyForStartFLJob(fbb, device_meta); result_code = ReadyForStartFLJob(fbb, device_meta);
if (result_code != ResultCode::kSuccess) { if (result_code != ResultCode::kSuccess) {
@ -122,16 +136,87 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
return true; return true;
} }
bool StartFLJobKernel::JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
std::string fl_id = start_fl_job_req->fl_id()->str();
std::string timestamp = start_fl_job_req->timestamp()->str();
auto sign_data_vector = start_fl_job_req->sign_data();
if (sign_data_vector == nullptr || sign_data_vector->size() == 0) {
std::string reason = "sign data is empty.";
BuildStartFLJobRsp(
fbb, schema::ResponseCode_RequestError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
return false;
}
unsigned char sign_data[sign_data_vector->size()];
for (unsigned int i = 0; i < sign_data_vector->size(); i++) {
sign_data[i] = sign_data_vector->Get(i);
}
std::string key_attestation = start_fl_job_req->key_attestation()->str();
std::string equip_cert = start_fl_job_req->equip_cert()->str();
std::string equip_ca_cert = start_fl_job_req->equip_ca_cert()->str();
std::string root_first_ca_path = ps::PSContext::instance()->root_first_ca_path();
std::string root_second_ca_path = ps::PSContext::instance()->root_second_ca_path();
std::string equip_crl_path = ps::PSContext::instance()->equip_crl_path();
mindspore::ps::server::CertVerify certVerify;
bool ret =
certVerify.verifyCertAndSign(fl_id, timestamp, (const unsigned char *)sign_data, key_attestation, equip_cert,
equip_ca_cert, root_first_ca_path, root_second_ca_path, equip_crl_path);
if (!ret) {
std::string reason = "startFLJob sign and certificate verify failed.";
BuildStartFLJobRsp(
fbb, schema::ResponseCode_RequestError, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(ERROR) << reason;
} else {
MS_LOG(INFO) << "JudgeFLJobVerify success." << ret;
}
return ret;
}
bool StartFLJobKernel::StoreKeyAttestation(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestFLJob *start_fl_job_req) {
// update key attestation
if (start_fl_job_req == nullptr) {
return false;
}
std::string fl_id = start_fl_job_req->fl_id()->str();
std::string key_attestation = start_fl_job_req->key_attestation()->str();
fl::PairKeyAttestation pair_key_attestation_pb;
pair_key_attestation_pb.set_fl_id(fl_id);
pair_key_attestation_pb.set_certificate(key_attestation);
fl::PBMetadata pb_data;
pb_data.mutable_pair_key_attestation()->MergeFrom(pair_key_attestation_pb);
bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxClientKeyAttestation, pb_data);
if (!ret) {
std::string reason = "startFLJob: store key attestation failed";
MS_LOG(ERROR) << reason;
BuildStartFLJobRsp(
fbb, schema::ResponseCode_OutOfTime, reason, false,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
return false;
}
return true;
}
bool StartFLJobKernel::Reset() { bool StartFLJobKernel::Reset() {
MS_LOG(INFO) << "Starting fl job kernel reset!"; MS_LOG(INFO) << "Starting fl job kernel reset!";
StopTimer(); StopTimer();
DistributedCountService::GetInstance().ResetCounter(name_); DistributedCountService::GetInstance().ResetCounter(name_);
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas); DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas);
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxClientKeyAttestation);
return true; return true;
} }
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) { void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
iter_next_req_timestamp_ = LongToSize(CURRENT_TIME_MILLI.count()) + iteration_time_window_; iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_); LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
// The first startFLJob request means a new iteration starts running. // The first startFLJob request means a new iteration starts running.
Iteration::GetInstance().SetIterationRunning(); Iteration::GetInstance().SetIterationRunning();
@ -241,9 +326,11 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
fl_plan_builder.add_epochs(SizeToInt(ps::PSContext::instance()->client_epoch_num())); fl_plan_builder.add_epochs(SizeToInt(ps::PSContext::instance()->client_epoch_num()));
fl_plan_builder.add_mini_batch(SizeToInt(ps::PSContext::instance()->client_batch_size())); fl_plan_builder.add_mini_batch(SizeToInt(ps::PSContext::instance()->client_batch_size()));
fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate()); fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate());
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
fl_plan_builder.add_cipher(cipher_public_params); fl_plan_builder.add_cipher(cipher_public_params);
#endif #endif
auto fbs_fl_plan = fl_plan_builder.Finish(); auto fbs_fl_plan = fl_plan_builder.Finish();
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps; std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;

View File

@ -58,6 +58,10 @@ class StartFLJobKernel : public RoundKernel {
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta); void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
bool JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
bool StoreKeyAttestation(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
// Build response for startFLJob round no matter success or failure. // Build response for startFLJob round no matter success or failure.
void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const bool is_selected, const std::string &next_req_time, const std::string &reason, const bool is_selected, const std::string &next_req_time,

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
#include "fl/server/kernel/round/update_model_kernel.h" #include "fl/server/kernel/round/update_model_kernel.h"
namespace mindspore { namespace mindspore {
@ -85,6 +86,29 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
return true; return true;
} }
// verify signature
if (ps::PSContext::instance()->pki_verify()) {
sigVerifyResult verify_result = VerifySignature(update_model_req);
if (verify_result == sigVerifyResult::FAILED) {
std::string reason = "verify signature failed.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::TIMEOUT) {
std::string reason = "verify signature timestamp failed.";
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, "");
MS_LOG(ERROR) << reason;
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
return true;
}
if (verify_result == sigVerifyResult::PASSED) {
MS_LOG(INFO) << "verify signature passed!";
}
}
result_code = UpdateModel(update_model_req, fbb); result_code = UpdateModel(update_model_req, fbb);
if (result_code != ResultCode::kSuccess) { if (result_code != ResultCode::kSuccess) {
MS_LOG(ERROR) << "Updating model failed."; MS_LOG(ERROR) << "Updating model failed.";
@ -144,12 +168,11 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
size_t iteration = IntToSize(update_model_req->iteration()); size_t iteration = IntToSize(update_model_req->iteration());
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) + ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) +
". Retry later."; ". Retry later at time: " + std::to_string(next_req_time);
BuildUpdateModelRsp( BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(next_req_time));
fbb, schema::ResponseCode_OutOfTime, reason,
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
MS_LOG(WARNING) << reason; MS_LOG(WARNING) << reason;
return ResultCode::kSuccessAndReturn; return ResultCode::kSuccessAndReturn;
} }
@ -259,6 +282,47 @@ ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilde
return ResultCode::kSuccess; return ResultCode::kSuccess;
} }
sigVerifyResult UpdateModelKernel::VerifySignature(const schema::RequestUpdateModel *update_model_req) {
std::string fl_id = update_model_req->fl_id()->str();
std::string timestamp = update_model_req->timestamp()->str();
int iteration = update_model_req->iteration();
std::string iter_str = std::to_string(iteration);
auto fbs_signature = update_model_req->signature();
std::vector<unsigned char> signature;
if (fbs_signature == nullptr) {
MS_LOG(ERROR) << "signature in client_list_sign_req is nullptr";
return sigVerifyResult::FAILED;
}
signature.assign(fbs_signature->begin(), fbs_signature->end());
std::map<std::string, std::string> key_attestations;
const fl::PBMetadata &key_attestations_pb_out =
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
auto iter = key_attestation_pb.key_attestations().begin();
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
}
if (key_attestations.find(fl_id) == key_attestations.end()) {
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
return sigVerifyResult::TIMEOUT;
}
std::vector<unsigned char> src_data;
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
mindspore::ps::server::CertVerify certVerify;
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
return sigVerifyResult::FAILED;
}
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
return sigVerifyResult::TIMEOUT;
}
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
return sigVerifyResult::PASSED;
}
void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time) { const std::string &reason, const std::string &next_req_time) {
if (fbb == nullptr) { if (fbb == nullptr) {

View File

@ -35,10 +35,12 @@ namespace server {
namespace kernel { namespace kernel {
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round. // The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
constexpr uint64_t kInitialDataSizeSum = 0; constexpr uint64_t kInitialDataSizeSum = 0;
// results of signature verification
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
class UpdateModelKernel : public RoundKernel { class UpdateModelKernel : public RoundKernel {
public: public:
UpdateModelKernel() : executor_(nullptr), iteration_time_window_(0) {} UpdateModelKernel() = default;
~UpdateModelKernel() override = default; ~UpdateModelKernel() override = default;
void InitKernel(size_t threshold_count) override; void InitKernel(size_t threshold_count) override;
@ -55,6 +57,7 @@ class UpdateModelKernel : public RoundKernel {
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req); std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb, ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
const schema::RequestUpdateModel *update_model_req); const schema::RequestUpdateModel *update_model_req);
sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req);
void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
const std::string &reason, const std::string &next_req_time); const std::string &reason, const std::string &next_req_time);
@ -62,7 +65,7 @@ class UpdateModelKernel : public RoundKernel {
Executor *executor_; Executor *executor_;
// The time window of one iteration. // The time window of one iteration.
size_t iteration_time_window_; size_t iteration_time_window_{0};
}; };
} // namespace kernel } // namespace kernel
} // namespace server } // namespace server

View File

@ -1,35 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fl/server/kernel/sgd_kernel.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
REG_OPTIMIZER_KERNEL(SGD,
ParamsInfo()
.AddInputNameType(kWeight, kNumberTypeFloat32)
.AddInputNameType(kGradient, kNumberTypeFloat32)
.AddInputNameType(kLearningRate, kNumberTypeFloat32)
.AddInputNameType(kAccumulation, kNumberTypeFloat32)
.AddInputNameType(kMomentum, kNumberTypeFloat32)
.AddInputNameType(kStat, kNumberTypeFloat32),
SGDKernel, float)
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -1,63 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_
#include <vector>
#include <memory>
#include <utility>
#include "backend/kernel_compiler/cpu/sgd_cpu_kernel.h"
#include "fl/server/kernel/optimizer_kernel.h"
#include "fl/server/kernel/optimizer_kernel_factory.h"
namespace mindspore {
namespace fl {
namespace server {
namespace kernel {
using mindspore::kernel::SGDCPUKernel;
template <typename T>
class SGDKernel : public SGDCPUKernel<T>, public OptimizerKernel {
public:
SGDKernel() = default;
~SGDKernel() override = default;
void InitKernel(const CNodePtr &cnode) override {
SGDCPUKernel<T>::InitKernel(cnode);
InitServerKernelInputOutputSize(cnode);
GenerateReuseKernelNodeInfo();
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return SGDCPUKernel<T>::Launch(inputs, workspace, outputs);
}
void GenerateReuseKernelNodeInfo() override {
MS_LOG(INFO) << "SGD reuse 'weight', 'learning rate', 'accumulation', 'momentum' and 'stat' of the kernel node.";
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0));
reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2));
reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 3));
reuse_kernel_node_inputs_info_.insert(std::make_pair(kMomentum, 4));
reuse_kernel_node_inputs_info_.insert(std::make_pair(kStat, 5));
return;
}
};
} // namespace kernel
} // namespace server
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_

View File

@ -20,7 +20,7 @@
#include <any> #include <any>
#include <mutex> #include <mutex>
#include <string> #include <string>
#include "utils/hash_map.h" #include <unordered_map>
#include "fl/server/common.h" #include "fl/server/common.h"
namespace mindspore { namespace mindspore {
@ -77,10 +77,10 @@ class LocalMetaStore {
LocalMetaStore &operator=(const LocalMetaStore &) = delete; LocalMetaStore &operator=(const LocalMetaStore &) = delete;
// key_to_meta_ stores metadata with key-value format. // key_to_meta_ stores metadata with key-value format.
mindspore::HashMap<std::string, std::any> key_to_meta_; std::unordered_map<std::string, std::any> key_to_meta_;
// This mutex makes sure that the operations on key_to_meta_ is threadsafe. // This mutex makes sure that the operations on key_to_meta_ is threadsafe.
std::mutex mtx_; std::mutex mtx_;
size_t curr_iter_num_; size_t curr_iter_num_{0};
}; };
} // namespace server } // namespace server
} // namespace fl } // namespace fl

View File

@ -31,7 +31,7 @@ namespace server {
constexpr size_t kInitIterationNum = 0; constexpr size_t kInitIterationNum = 0;
// The initial iteration number after ModelStore is reset. // The initial iteration number after ModelStore is reset.
constexpr size_t kResetInitIterNum = 1; constexpr size_t kResetInitialIterNum = 1;
// Server framework use ModelStore to store and query models. // Server framework use ModelStore to store and query models.
// ModelStore stores multiple models because worker could get models of the previous iterations. // ModelStore stores multiple models because worker could get models of the previous iterations.

View File

@ -111,39 +111,6 @@ bool ParameterAggregator::LaunchAggregators() {
return true; return true;
} }
bool ParameterAggregator::LaunchOptimizers() {
for (auto &optimizer_with_params : optimizer_kernel_parameters_) {
KernelParams &params = optimizer_with_params.second;
std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel = optimizer_with_params.first;
MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs);
if (!ret) {
MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed.";
continue;
}
}
// As long as all the optimizer kernels are launched, consider optimizing for this ParameterAggregator as done.
optimizing_done_ = true;
return true;
}
AddressPtr ParameterAggregator::Pull() {
if (memory_register_ == nullptr) {
MS_LOG(ERROR)
<< "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
return nullptr;
}
current_pull_count_++;
if (current_pull_count_ == required_pull_count_) {
pulling_done_ = true;
}
MS_LOG(DEBUG) << "The " << current_pull_count_ << " time of Pull. Pulling done status: " << pulling_done_;
std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
return name_to_addr["weight"];
}
AddressPtr ParameterAggregator::GetWeight() { AddressPtr ParameterAggregator::GetWeight() {
if (memory_register_ == nullptr) { if (memory_register_ == nullptr) {
MS_LOG(ERROR) MS_LOG(ERROR)
@ -193,8 +160,8 @@ bool ParameterAggregator::requires_aggr() const { return requires_aggr_; }
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) { bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (!JudgeRequiresAggr(cnode)) { if (!JudgeRequiredAggr(cnode)) {
MS_LOG(WARNING) << "Aggregation for weight for kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required."; MS_LOG(WARNING) << "Aggregation for weight of kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required.";
} }
std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode); std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
@ -223,32 +190,12 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
return true; return true;
} }
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &) {
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL || if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) { ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
return true; return true;
} }
MS_EXCEPTION_IF_NULL(cnode);
const std::string &name = AnfAlgo::GetCNodeName(cnode);
auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode);
if (optimizer_kernel == nullptr) {
MS_LOG(EXCEPTION) << "Failed to create optimizer kernel for " << name;
return false;
}
optimizer_kernel->InitKernel(cnode);
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = optimizer_kernel->reuse_kernel_node_inputs_info();
if (!AssignMemory(optimizer_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
return false;
}
if (!GenerateOptimizerKernelParams(optimizer_kernel, memory_register_)) {
MS_LOG(ERROR) << "Generating optimizer kernel parameters failed.";
return false;
}
return true; return true;
} }
@ -323,29 +270,6 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<
return true; return true;
} }
bool ParameterAggregator::GenerateOptimizerKernelParams(
const std::shared_ptr<kernel::OptimizerKernel> &optimizer_kernel,
const std::shared_ptr<MemoryRegister> &memory_register) {
MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
KernelParams optimizer_params = {};
const std::vector<std::string> &input_names = optimizer_kernel->input_names();
(void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names();
(void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
[&](const std::string &name) { return memory_register->addresses()[name]; });
const std::vector<std::string> &output_names = optimizer_kernel->output_names();
(void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
[&](const std::string &name) { return memory_register->addresses()[name]; });
optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params));
return true;
}
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) { std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
std::vector<std::string> aggregation_algorithm = {}; std::vector<std::string> aggregation_algorithm = {};
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL || if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
@ -362,7 +286,7 @@ std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const C
return aggregation_algorithm; return aggregation_algorithm;
} }
bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) { bool ParameterAggregator::JudgeRequiredAggr(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
std::string cnode_name = AnfAlgo::GetCNodeName(cnode); std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 || if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
@ -375,7 +299,7 @@ bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(weight_node); MS_EXCEPTION_IF_NULL(weight_node);
if (!weight_node->isa<Parameter>()) { if (!weight_node->isa<Parameter>()) {
MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter node."; MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter.";
return false; return false;
} }
auto param_info = weight_node->cast<ParameterPtr>()->param_info(); auto param_info = weight_node->cast<ParameterPtr>()->param_info();

View File

@ -77,11 +77,6 @@ class ParameterAggregator {
// Launch aggregators/optimizers of this ParameterAggregator in order. // Launch aggregators/optimizers of this ParameterAggregator in order.
bool LaunchAggregators(); bool LaunchAggregators();
bool LaunchOptimizers();
// The implementation for primitive Pull in parameter server training mode.
// Every call of this method will increase the count for pull by 1.
AddressPtr Pull();
// Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing // Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing
// any change of status. // any change of status.
@ -98,7 +93,6 @@ class ParameterAggregator {
bool IsOptimizingDone() const; bool IsOptimizingDone() const;
bool IsPullingDone() const; bool IsPullingDone() const;
// Return whether this parameter requires aggragation.
bool requires_aggr() const; bool requires_aggr() const;
private: private:
@ -119,15 +113,13 @@ class ParameterAggregator {
// memory_register. // memory_register.
bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel, bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
const std::shared_ptr<MemoryRegister> &memory_register); const std::shared_ptr<MemoryRegister> &memory_register);
bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> &optim_kernel,
const std::shared_ptr<MemoryRegister> &memory_register);
// The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user // The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
// configuration, etc. // configuration, etc.
std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode); std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);
// Judge whether the parameter needs to be aggregated. // Judge whether the parameter needs to be aggregated.
bool JudgeRequiresAggr(const CNodePtr &cnode); bool JudgeRequiredAggr(const CNodePtr &cnode);
ServerMode server_mode_; ServerMode server_mode_;
size_t required_push_count_; size_t required_push_count_;

View File

@ -38,17 +38,16 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
const FinishIterCb &finish_iteration_cb) { const FinishIterCb &finish_iteration_cb) {
MS_EXCEPTION_IF_NULL(communicator); MS_EXCEPTION_IF_NULL(communicator);
communicator_ = communicator; communicator_ = communicator;
MS_LOG(INFO) << "Round " << name_ << " start initialize.";
// Register callback for round kernel.
communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) { communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message); MS_ERROR_IF_NULL_WO_RET_VAL(message);
LaunchRoundKernel(message); LaunchRoundKernel(message);
}); });
// Callback when the iteration is finished. // Callback when the iteration is finished.
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void { finish_iteration_cb_ = [this, finish_iteration_cb](bool, const std::string &) -> void {
std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration."; std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
finish_iteration_cb(is_iteration_valid, reason); finish_iteration_cb(true, reason);
}; };
// Callback for finalizing the server. This can only be called once. // Callback for finalizing the server. This can only be called once.
@ -62,9 +61,9 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
MS_EXCEPTION_IF_NULL(iter_timer_); MS_EXCEPTION_IF_NULL(iter_timer_);
// 1.Set the timeout callback for the timer. // 1.Set the timeout callback for the timer.
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void { iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool, const std::string &) -> void {
std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration."; std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration.";
timeout_cb(is_iteration_valid, reason); timeout_cb(false, reason);
}); });
// 2.Stopping timer callback which will be set to the round kernel. // 2.Stopping timer callback which will be set to the round kernel.
@ -139,7 +138,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
return; return;
} }
++Iteration::GetInstance().running_round_num_; (void)(Iteration::GetInstance().running_round_num_++);
AddressPtr input = std::make_shared<Address>(); AddressPtr input = std::make_shared<Address>();
AddressPtr output = std::make_shared<Address>(); AddressPtr output = std::make_shared<Address>();
MS_ERROR_IF_NULL_WO_RET_VAL(input); MS_ERROR_IF_NULL_WO_RET_VAL(input);
@ -167,7 +166,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
reason = "Launching round kernel of round " + name_ + " failed."; reason = "Launching round kernel of round " + name_ + " failed.";
Iteration::GetInstance().NotifyNext(false, reason); Iteration::GetInstance().NotifyNext(false, reason);
} }
--Iteration::GetInstance().running_round_num_; (void)(Iteration::GetInstance().running_round_num_--);
return; return;
} }

View File

@ -32,7 +32,7 @@
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
// The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster managers like K8S. // The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster manager like K8S.
std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr; std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr;
std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {}; std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {};
void SignalHandler(int signal) { void SignalHandler(int signal) {
@ -45,7 +45,6 @@ void SignalHandler(int signal) {
MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server); MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server);
(void)g_communicator_with_server->Stop(); (void)g_communicator_with_server->Stop();
return;
} }
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config, void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
@ -68,24 +67,10 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
return; return;
} }
// Each step of the server pipeline may have dependency on other steps, which includes:
// InitServerContext must be the first step to set contexts for later steps.
// Server Running relies on URL or Message Type Register:
// StartCommunicator---->InitIteration
// Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion:
// RegisterRoundKernel---->StartCommunicator
// Kernel Initialization relies on Executor Initialization:
// RegisterRoundKernel---->InitExecutor
// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
// InitCipher---->InitExecutor
void Server::Run() { void Server::Run() {
std::unique_lock<std::mutex> lock(scaling_mtx_); std::unique_lock<std::mutex> lock(scaling_mtx_);
InitServerContext(); InitServerContext();
InitPkiCertificate();
InitCluster(); InitCluster();
InitIteration(); InitIteration();
RegisterCommCallbacks(); RegisterCommCallbacks();
@ -98,6 +83,7 @@ void Server::Run() {
} }
RegisterRoundKernel(); RegisterRoundKernel();
InitMetrics(); InitMetrics();
Recover();
MS_LOG(INFO) << "Server started successfully."; MS_LOG(INFO) << "Server started successfully.";
safemode_ = false; safemode_ = false;
lock.unlock(); lock.unlock();
@ -111,9 +97,27 @@ void Server::Run() {
MS_EXCEPTION_IF_NULL(communicator_with_server_); MS_EXCEPTION_IF_NULL(communicator_with_server_);
communicator_with_server_->Join(); communicator_with_server_->Join();
MsException::Instance().CheckException(); MsException::Instance().CheckException();
func_graph_ = nullptr;
return; return;
} }
void Server::InitPkiCertificate() {
if (ps::PSContext::instance()->pki_verify()) {
root_first_ca_path_ = ps::PSContext::instance()->root_first_ca_path();
root_second_ca_path_ = ps::PSContext::instance()->root_second_ca_path();
equip_crl_path_ = ps::PSContext::instance()->equip_crl_path();
replay_attack_time_diff_ = ps::PSContext::instance()->replay_attack_time_diff();
bool ret = mindspore::ps::server::CertVerify::initRootCertAndCRL(root_first_ca_path_, root_second_ca_path_,
equip_crl_path_, replay_attack_time_diff_);
if (!ret) {
MS_LOG(EXCEPTION) << "init root cert and crl failed.";
return;
}
return;
}
}
void Server::SwitchToSafeMode() { void Server::SwitchToSafeMode() {
MS_LOG(INFO) << "Server switch to safemode."; MS_LOG(INFO) << "Server switch to safemode.";
safemode_ = true; safemode_ = true;
@ -138,11 +142,6 @@ void Server::InitServerContext() {
scheduler_port_ = ps::PSContext::instance()->scheduler_port(); scheduler_port_ = ps::PSContext::instance()->scheduler_port();
worker_num_ = ps::PSContext::instance()->initial_worker_num(); worker_num_ = ps::PSContext::instance()->initial_worker_num();
server_num_ = ps::PSContext::instance()->initial_server_num(); server_num_ = ps::PSContext::instance()->initial_server_num();
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType && server_num_ > 1) {
MS_LOG(EXCEPTION) << "Only single server is supported for PW_ENCRYPT now, but got server_num is:." << server_num_;
return;
}
return; return;
} }
@ -218,6 +217,8 @@ void Server::InitIteration() {
cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold; cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold; cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold; cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
cipher_push_list_sign_cnt_ = cipher_config_.push_list_sign_threshold;
cipher_get_list_sign_cnt_ = cipher_config_.get_list_sign_threshold;
cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold; cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1; cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
cipher_time_window_ = cipher_config_.cipher_time_window; cipher_time_window_ = cipher_config_.cipher_time_window;
@ -228,6 +229,8 @@ void Server::InitIteration() {
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_; << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_ MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
<< " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_ << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
<< " cipher_push_list_sign_cnt_: " << cipher_push_list_sign_cnt_
<< " cipher_get_list_sign_cnt_: " << cipher_get_list_sign_cnt_
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_ << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_ << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
<< " cipher_time_window_: " << cipher_time_window_; << " cipher_time_window_: " << cipher_time_window_;
@ -245,6 +248,7 @@ void Server::InitIteration() {
void Server::InitCipher() { void Server::InitCipher() {
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
cipher_init_ = &armour::CipherInit::GetInstance(); cipher_init_ = &armour::CipherInit::GetInstance();
int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_); int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
unsigned char cipher_p[SECRET_MAX_LEN] = {0}; unsigned char cipher_p[SECRET_MAX_LEN] = {0};
const int cipher_g = 1; const int cipher_g = 1;
@ -258,8 +262,7 @@ void Server::InitCipher() {
param.t = cipher_t; param.t = cipher_t;
int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p)); int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p));
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; MS_LOG(EXCEPTION) << "Memcpy_s error, errorno" << ret;
return;
} }
param.dp_delta = dp_delta; param.dp_delta = dp_delta;
param.dp_eps = dp_eps; param.dp_eps = dp_eps;
@ -281,10 +284,9 @@ void Server::InitCipher() {
if (prim != NULL) { if (prim != NULL) {
BN_clear_free(prim); BN_clear_free(prim);
} }
cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
(void)cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_, cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_push_list_sign_cnt_,
cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_, cipher_get_list_sign_cnt_, cipher_reconstruct_secrets_up_cnt_);
cipher_reconstruct_secrets_up_cnt_);
#endif #endif
} }
@ -366,12 +368,14 @@ void Server::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunic
std::bind(&Server::HandleNewInstanceRequest, this, std::placeholders::_1)); std::bind(&Server::HandleNewInstanceRequest, this, std::placeholders::_1));
communicator->RegisterMsgCallBack("queryInstance", communicator->RegisterMsgCallBack("queryInstance",
std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1)); std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1));
communicator->RegisterMsgCallBack("syncAfterRecover",
std::bind(&Server::HandleSyncAfterRecoveryRequest, this, std::placeholders::_1));
} }
void Server::InitExecutor() { void Server::InitExecutor() {
MS_EXCEPTION_IF_NULL(func_graph_); MS_EXCEPTION_IF_NULL(func_graph_);
if (executor_threshold_ == 0) { if (executor_threshold_ == 0) {
MS_LOG(EXCEPTION) << "The executor's threshold should be greater than 0."; MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
return; return;
} }
// The train engine instance is used in both push-type and pull-type kernels, // The train engine instance is used in both push-type and pull-type kernels,
@ -430,7 +434,6 @@ void Server::StartCommunicator() {
MS_LOG(INFO) << "Start communicator with server."; MS_LOG(INFO) << "Start communicator with server.";
if (!communicator_with_server_->Start()) { if (!communicator_with_server_->Start()) {
MS_LOG(EXCEPTION) << "Starting communicator with server failed."; MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
return;
} }
DistributedMetadataStore::GetInstance().Initialize(server_node_); DistributedMetadataStore::GetInstance().Initialize(server_node_);
CollectiveOpsImpl::GetInstance().Initialize(server_node_); CollectiveOpsImpl::GetInstance().Initialize(server_node_);
@ -447,6 +450,32 @@ void Server::StartCommunicator() {
}); });
} }
void Server::Recover() {
server_recovery_ = std::make_shared<ServerRecovery>();
MS_EXCEPTION_IF_NULL(server_recovery_);
// Try to recovery from persistent storage.
if (!server_recovery_->Initialize(ps::PSContext::instance()->config_file_path())) {
MS_LOG(WARNING) << "Initializing server recovery failed. Do not recover for this server.";
return;
}
if (server_recovery_->Recover()) {
// If this server recovers, need to notify cluster to reach consistency.
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
MS_LOG(INFO) << "Synchronize with leader server after recovery.";
if (!server_recovery_->SyncAfterRecovery(tcp_comm, server_node_->rank_id())) {
MS_LOG(EXCEPTION) << "Failed to reach consistency of the cluster after recovery.";
return;
}
}
// Set the recovery handler to Iteration.
MS_EXCEPTION_IF_NULL(iteration_);
iteration_->set_recovery_handler(server_recovery_);
}
void Server::ProcessBeforeScalingOut() { void Server::ProcessBeforeScalingOut() {
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
iteration_->ScalingBarrier(); iteration_->ScalingBarrier();
@ -510,7 +539,6 @@ void Server::ProcessAfterScalingIn() {
} }
void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
@ -528,7 +556,6 @@ void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHa
} }
void Server::HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) { void Server::HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
@ -552,7 +579,6 @@ void Server::HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHan
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
MS_ERROR_IF_NULL_WO_RET_VAL(message->data());
std::string hyper_params_str(static_cast<const char *>(message->data()), message->len()); std::string hyper_params_str(static_cast<const char *>(message->data()), message->len());
nlohmann::json new_instance_json; nlohmann::json new_instance_json;
nlohmann::json response; nlohmann::json response;
@ -595,6 +621,31 @@ void Server::HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageH
return; return;
} }
} }
void Server::HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
MS_ERROR_IF_NULL_WO_RET_VAL(message);
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
MS_LOG(INFO) << "Receive SyncAfterRecover request from other server.";
std::string response = "success";
if (!tcp_comm->SendResponse(response.c_str(), response.size(), message)) {
MS_LOG(ERROR) << "Sending response of SyncAfterRecoverRequest failed.";
return;
}
if (!safemode_.load()) {
MS_LOG(INFO) << "Need to synchronize for other server's recovery";
SyncAfterRecover sync_after_recovery_req;
(void)sync_after_recovery_req.ParseFromArray(message->data(), SizeToInt(message->len()));
if (!iteration_->SyncAfterRecovery(sync_after_recovery_req.current_iter_num())) {
MS_LOG(ERROR) << "Sync after recovery failed.";
return;
}
}
}
} // namespace server } // namespace server
} // namespace fl } // namespace fl
} // namespace mindspore } // namespace mindspore

View File

@ -24,19 +24,19 @@
#include "ps/core/communicator/tcp_communicator.h" #include "ps/core/communicator/tcp_communicator.h"
#include "ps/core/communicator/task_executor.h" #include "ps/core/communicator/task_executor.h"
#include "ps/core/file_configuration.h" #include "ps/core/file_configuration.h"
#include "fl/server/common.h"
#include "fl/server/executor.h"
#include "fl/server/iteration.h"
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
#include "fl/armour/cipher/cipher_init.h" #include "fl/armour/cipher/cipher_init.h"
#endif #endif
#include "fl/server/common.h"
#include "fl/server/executor.h"
#include "fl/server/iteration.h"
namespace mindspore { namespace mindspore {
namespace fl { namespace fl {
namespace server { namespace server {
// The sleeping time of the server thread before the networking is completed. // The sleeping time of the server thread before the networking is completed.
constexpr uint32_t kServerSleepTimeForNetworking = 1000; constexpr uint32_t kServerSleepTimeForNetworking = 1000;
constexpr uint64_t kDefaultReplayAttackTimeDiff = 60000;
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning. // Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
class Server { class Server {
public: public:
@ -51,6 +51,21 @@ class Server {
// According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
// blocked until the server is finalized. // blocked until the server is finalized.
// func_graph is the frontend graph which will be parse in server's exector and aggregator. // func_graph is the frontend graph which will be parse in server's exector and aggregator.
// Each step of the server pipeline may have dependency on other steps, which includes:
// InitServerContext must be the first step to set contexts for later steps.
// Server Running relies on URL or Message Type Register:
// StartCommunicator---->InitIteration
// Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion:
// RegisterRoundKernel---->StartCommunicator
// Kernel Initialization relies on Executor Initialization:
// RegisterRoundKernel---->InitExecutor
// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
// InitCipher---->InitExecutor
void Run(); void Run();
void SwitchToSafeMode(); void SwitchToSafeMode();
@ -74,17 +89,25 @@ class Server {
communicators_with_worker_({}), communicators_with_worker_({}),
iteration_(nullptr), iteration_(nullptr),
safemode_(true), safemode_(true),
server_recovery_(nullptr),
scheduler_ip_(""), scheduler_ip_(""),
scheduler_port_(0), scheduler_port_(0),
server_num_(0), server_num_(0),
worker_num_(0), worker_num_(0),
fl_server_port_(0), fl_server_port_(0),
pki_verify_(false),
root_first_ca_path_(""),
root_second_ca_path_(""),
equip_crl_path_(""),
replay_attack_time_diff_(kDefaultReplayAttackTimeDiff),
cipher_initial_client_cnt_(0), cipher_initial_client_cnt_(0),
cipher_exchange_keys_cnt_(0), cipher_exchange_keys_cnt_(0),
cipher_get_keys_cnt_(0), cipher_get_keys_cnt_(0),
cipher_share_secrets_cnt_(0), cipher_share_secrets_cnt_(0),
cipher_get_secrets_cnt_(0), cipher_get_secrets_cnt_(0),
cipher_get_clientlist_cnt_(0), cipher_get_clientlist_cnt_(0),
cipher_push_list_sign_cnt_(0),
cipher_get_list_sign_cnt_(0),
cipher_reconstruct_secrets_up_cnt_(0), cipher_reconstruct_secrets_up_cnt_(0),
cipher_reconstruct_secrets_down_cnt_(0), cipher_reconstruct_secrets_down_cnt_(0),
cipher_time_window_(0) {} cipher_time_window_(0) {}
@ -95,9 +118,6 @@ class Server {
// Load variables which is set by ps_context. // Load variables which is set by ps_context.
void InitServerContext(); void InitServerContext();
// Try to recover server config from persistent storage.
void Recovery();
// Initialize the server cluster, server node and communicators. // Initialize the server cluster, server node and communicators.
void InitCluster(); void InitCluster();
bool InitCommunicatorWithServer(); bool InitCommunicatorWithServer();
@ -130,6 +150,12 @@ class Server {
// The communicators should be started after all initializations are completed. // The communicators should be started after all initializations are completed.
void StartCommunicator(); void StartCommunicator();
// Try to recover server config from persistent storage.
void Recover();
// load pki huks cbg root certificate and crl
void InitPkiCertificate();
// The barriers before scaling operations. // The barriers before scaling operations.
void ProcessBeforeScalingOut(); void ProcessBeforeScalingOut();
void ProcessBeforeScalingIn(); void ProcessBeforeScalingIn();
@ -148,6 +174,9 @@ class Server {
// Query current instance information. // Query current instance information.
void HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message); void HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// Synchronize after recovery is completed to ensure consistency.
void HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
// The server node is initialized in Server. // The server node is initialized in Server.
std::shared_ptr<ps::core::ServerNode> server_node_; std::shared_ptr<ps::core::ServerNode> server_node_;
@ -179,7 +208,7 @@ class Server {
// communicators. // communicators.
std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_; std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
// Mutex for scaling operations. We must wait server's initialization done before handle scaling events. // Mutex for scaling operations.
std::mutex scaling_mtx_; std::mutex scaling_mtx_;
// Iteration consists of multiple kinds of rounds. // Iteration consists of multiple kinds of rounds.
@ -189,21 +218,33 @@ class Server {
// If true, the server is not available to workers and clients. // If true, the server is not available to workers and clients.
std::atomic_bool safemode_; std::atomic_bool safemode_;
// The recovery object for server.
std::shared_ptr<ServerRecovery> server_recovery_;
// Variables set by ps context. // Variables set by ps context.
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
armour::CipherInit *cipher_init_{nullptr}; armour::CipherInit *cipher_init_;
#endif #endif
std::string scheduler_ip_; std::string scheduler_ip_;
uint16_t scheduler_port_; uint16_t scheduler_port_;
uint32_t server_num_; uint32_t server_num_;
uint32_t worker_num_; uint32_t worker_num_;
uint16_t fl_server_port_; uint16_t fl_server_port_;
bool pki_verify_;
std::string root_first_ca_path_;
std::string root_second_ca_path_;
std::string equip_crl_path_;
uint64_t replay_attack_time_diff_;
size_t cipher_initial_client_cnt_; size_t cipher_initial_client_cnt_;
size_t cipher_exchange_keys_cnt_; size_t cipher_exchange_keys_cnt_;
size_t cipher_get_keys_cnt_; size_t cipher_get_keys_cnt_;
size_t cipher_share_secrets_cnt_; size_t cipher_share_secrets_cnt_;
size_t cipher_get_secrets_cnt_; size_t cipher_get_secrets_cnt_;
size_t cipher_get_clientlist_cnt_; size_t cipher_get_clientlist_cnt_;
size_t cipher_push_list_sign_cnt_;
size_t cipher_get_list_sign_cnt_;
size_t cipher_reconstruct_secrets_up_cnt_; size_t cipher_reconstruct_secrets_up_cnt_;
size_t cipher_reconstruct_secrets_down_cnt_; size_t cipher_reconstruct_secrets_down_cnt_;
uint64_t cipher_time_window_; uint64_t cipher_time_window_;

View File

@ -0,0 +1,114 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fl/server/server_recovery.h"
#include "fl/server/local_meta_store.h"
#include "debug/common.h"
namespace mindspore {
namespace fl {
namespace server {
bool ServerRecovery::Initialize(const std::string &config_file) {
config_ = std::make_unique<ps::core::FileConfiguration>(config_file);
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(EXCEPTION) << "Initializing for server recovery failed. Config file path " << config_file
<< " may be invalid or not exist.";
return false;
}
// Read the server recovery file path.
if (!config_->Exists(kServerRecovery)) {
MS_LOG(WARNING) << "Server recovery config is not set. This node doesn't support recovery.";
return true;
} else {
std::string value = config_->Get(kServerRecovery, "");
nlohmann::json value_json;
try {
value_json = nlohmann::json::parse(value);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "The data is not in json format.";
return false;
}
// Parse the storage type.
uint32_t storage_type = JsonGetKeyWithException<uint32_t>(value_json, ps::kStoreType);
if (std::to_string(storage_type) != ps::kFileStorage) {
MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported.";
return false;
}
// Parse storage file path.
server_recovery_file_path_ = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
MS_LOG(INFO) << "Server recovery file path is " << server_recovery_file_path_;
}
return true;
}
bool ServerRecovery::Recover() {
server_recovery_file_.open(server_recovery_file_path_, std::ios::in);
if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) {
MS_LOG(WARNING) << "Can't open server recovery file " << server_recovery_file_path_;
return false;
}
nlohmann::json server_recovery_json;
try {
server_recovery_json = nlohmann::json::parse(server_recovery_file_);
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "The server recovery file is not in json format.";
return false;
}
uint64_t current_iter = JsonGetKeyWithException<uint64_t>(server_recovery_json, kCurrentIteration);
LocalMetaStore::GetInstance().set_curr_iter_num(current_iter);
MS_LOG(INFO) << "Recover from persistent storage: current iteration number is " << current_iter;
server_recovery_file_.close();
return true;
}
bool ServerRecovery::Save(uint64_t current_iter) {
server_recovery_file_.open(server_recovery_file_path_, std::ios::out | std::ios::ate);
if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) {
MS_LOG(WARNING) << "Can't save data to recovery file " << server_recovery_file_path_
<< ". This file path is invalid or does not exit.";
return false;
}
nlohmann::json server_metadata_json;
server_metadata_json[kCurrentIteration] = current_iter;
server_recovery_file_ << server_metadata_json;
server_recovery_file_.close();
return true;
}
bool ServerRecovery::SyncAfterRecovery(const std::shared_ptr<ps::core::TcpCommunicator> &communicator,
uint32_t rank_id) {
// If this server is follower server, notify leader server that this server has recovered.
if (rank_id != kLeaderServerRank) {
MS_ERROR_IF_NULL_W_RET_VAL(communicator, false);
SyncAfterRecover sync_after_recover_req;
sync_after_recover_req.set_current_iter_num(LocalMetaStore::GetInstance().curr_iter_num());
if (!communicator->SendPbRequest(sync_after_recover_req, kLeaderServerRank,
ps::core::TcpUserCommand::kSyncAfterRecover)) {
MS_LOG(ERROR) << "Sending sync after recovery message to leader server failed.";
return false;
}
}
return true;
}
} // namespace server
} // namespace fl
} // namespace mindspore

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_RECOVERY_H_
#define MINDSPORE_CCSRC_FL_SERVER_SERVER_RECOVERY_H_
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "ps/core/recovery_base.h"
#include "ps/core/file_configuration.h"
#include "ps/core/communicator/tcp_communicator.h"
#include "ps/ps_context.h"
namespace mindspore {
namespace fl {
namespace server {
constexpr auto kServerRecovery = "server_recovery";
// The class helps server node to do recovery operation.
// Different from the recovery process in ps/core/node_recovery.*, this class focus on recovery of the server data. For
// example, current iteration number, learning rate, etc.
class ServerRecovery : public ps::core::RecoveryBase {
public:
ServerRecovery() : config_(nullptr), server_recovery_file_path_("") {}
~ServerRecovery() override = default;
bool Initialize(const std::string &config_file) override;
bool Recover() override;
// Save server's metadata to persistent storage.
bool Save(uint64_t current_iter);
// If this server recovers, need to notify cluster to reach consistency.
bool SyncAfterRecovery(const std::shared_ptr<ps::core::TcpCommunicator> &communicator, uint32_t rank_id);
private:
// This is the main config file set by ps context.
std::unique_ptr<ps::core::FileConfiguration> config_;
// The server recovery file path.
std::string server_recovery_file_path_;
// The server recovery file object.
std::fstream server_recovery_file_;
};
} // namespace server
} // namespace fl
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FL_SERVER_SERVER_RECOVERY_H_

View File

@ -220,9 +220,9 @@ void FLWorker::InitializeFollowerScaler() {
std::bind(&FLWorker::ProcessAfterScalingOut, this)); std::bind(&FLWorker::ProcessAfterScalingOut, this));
worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline", worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
std::bind(&FLWorker::ProcessAfterScalingIn, this)); std::bind(&FLWorker::ProcessAfterScalingIn, this));
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning), worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::UserDefineEvent::kIterationRunning),
std::bind(&FLWorker::HandleIterationRunningEvent, this)); std::bind(&FLWorker::HandleIterationRunningEvent, this));
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted), worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::UserDefineEvent::kIterationCompleted),
std::bind(&FLWorker::HandleIterationCompletedEvent, this)); std::bind(&FLWorker::HandleIterationCompletedEvent, this));
} }

View File

@ -51,7 +51,7 @@ constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;
// The time duration between retrying when server is in safemode. // The time duration between retrying when server is in safemode.
constexpr uint32_t kWorkerRetryDurationForSafeMode = 500; constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;
// The rank of the leader server. // The leader server rank.
constexpr uint32_t kLeaderServerRank = 0; constexpr uint32_t kLeaderServerRank = 0;
// The timeout for worker sending message to server in case of network jitter. // The timeout for worker sending message to server in case of network jitter.

View File

@ -879,6 +879,10 @@ bool StartServerAction(const ResourcePtr &res) {
std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold); std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold);
size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)), size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)),
reconstruct_secrets_threshold); reconstruct_secrets_threshold);
size_t push_list_sign_threshold = std::max(
static_cast<size_t>(std::ceil(client_list_threshold * share_secrets_ratio)), reconstruct_secrets_threshold);
size_t get_list_sign_threshold = std::max(
static_cast<size_t>(std::ceil(push_list_sign_threshold * share_secrets_ratio)), reconstruct_secrets_threshold);
#ifdef ENABLE_ARMOUR #ifdef ENABLE_ARMOUR
std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
if (encrypt_type == ps::kPWEncryptType) { if (encrypt_type == ps::kPWEncryptType) {
@ -889,6 +893,10 @@ bool StartServerAction(const ResourcePtr &res) {
rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold}); rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold});
rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold}); rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold}); rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
if (ps::PSContext::instance()->pki_verify()) {
rounds_config.push_back({"pushListSign", true, cipher_time_window, true, push_list_sign_threshold});
rounds_config.push_back({"getListSign", true, cipher_time_window, true, get_list_sign_threshold});
}
} }
if (encrypt_type == ps::kStablePWEncryptType) { if (encrypt_type == ps::kStablePWEncryptType) {
MS_LOG(INFO) << "Add stable secure aggregation rounds."; MS_LOG(INFO) << "Add stable secure aggregation rounds.";
@ -898,7 +906,8 @@ bool StartServerAction(const ResourcePtr &res) {
#endif #endif
fl::server::CipherConfig cipher_config = { fl::server::CipherConfig cipher_config = {
share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold, share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,
share_secrets_threshold, get_secrets_threshold, client_list_threshold, reconstruct_secrets_threshold}; share_secrets_threshold, get_secrets_threshold, client_list_threshold, push_list_sign_threshold,
get_list_sign_threshold, reconstruct_secrets_threshold};
size_t executor_threshold = 0; size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {

View File

@ -403,6 +403,7 @@ PYBIND11_MODULE(_c_expression, m) {
"Set threshold count ratio for share secrets round.") "Set threshold count ratio for share secrets round.")
.def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.") .def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.")
.def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.") .def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.")
.def("cipher_time_window", &PSContext::cipher_time_window, "Get time window for cipher rounds.")
.def("set_reconstruct_secrets_threshold", &PSContext::set_reconstruct_secrets_threshold, .def("set_reconstruct_secrets_threshold", &PSContext::set_reconstruct_secrets_threshold,
"Set threshold count for reconstruct secrets round.") "Set threshold count for reconstruct secrets round.")
.def("reconstruct_secrets_threshold", &PSContext::reconstruct_secrets_threshold, .def("reconstruct_secrets_threshold", &PSContext::reconstruct_secrets_threshold,
@ -416,16 +417,38 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.") .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
.def("client_batch_size", &PSContext::client_batch_size, "Get federated learning client batch size.") .def("client_batch_size", &PSContext::client_batch_size, "Get federated learning client batch size.")
.def("set_client_learning_rate", &PSContext::set_client_learning_rate, .def("set_client_learning_rate", &PSContext::set_client_learning_rate,
"Set worker's standalone training step number before communicating with server.") "Set federated learning client learning rate.")
.def("client_learning_rate", &PSContext::client_learning_rate, .def("client_learning_rate", &PSContext::client_learning_rate,
"Get worker's standalone training step number before communicating with server.") "Get worker's standalone training step number before communicating with server.")
.def("set_worker_step_num_per_iteration", &PSContext::set_worker_step_num_per_iteration, .def("set_worker_step_num_per_iteration", &PSContext::set_worker_step_num_per_iteration,
"Set federated learning client learning rate.") "Set worker's standalone training step number before communicating with server..")
.def("worker_step_num_per_iteration", &PSContext::worker_step_num_per_iteration, .def("worker_step_num_per_iteration", &PSContext::worker_step_num_per_iteration,
"Get federated learning client learning rate.") "Get federated learning client learning rate.")
.def("set_secure_aggregation", &PSContext::set_secure_aggregation,
"Set federated learning client using secure aggregation.")
.def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.")
.def("dp_eps", &PSContext::dp_eps, "Get dp epsilon for federated learning secure aggregation.")
.def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.")
.def("dp_delta", &PSContext::dp_delta, "Get dp delta for federated learning secure aggregation.")
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
"Set dp norm clip for federated learning secure aggregation.")
.def("dp_norm_clip", &PSContext::dp_norm_clip, "Get dp norm clip for federated learning secure aggregation.")
.def("set_encrypt_type", &PSContext::set_encrypt_type,
"Set encrypt type for federated learning secure aggregation.")
.def("encrypt_type", &PSContext::encrypt_type, "Get encrypt type for federated learning secure aggregation.")
.def("set_root_first_ca_path", &PSContext::set_root_first_ca_path, "Set root first ca path.")
.def("root_first_ca_path", &PSContext::root_first_ca_path, "Get root first ca path.")
.def("set_root_second_ca_path", &PSContext::set_root_second_ca_path, "Set root second ca path.")
.def("root_second_ca_path", &PSContext::root_second_ca_path, "Get root second ca path.")
.def("set_pki_verify", &PSContext::set_pki_verify, "Set pki verify.")
.def("pki_verify", &PSContext::pki_verify, "Get pki verify.")
.def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port, .def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port,
"Set scheduler manage port used to scale out/in.") "Set scheduler manage port used to scale out/in.")
.def("scheduler_manage_port", &PSContext::scheduler_manage_port, "Get scheduler manage port used to scale out/in.") .def("scheduler_manage_port", &PSContext::scheduler_manage_port, "Get scheduler manage port used to scale out/in.")
.def("set_equip_crl_path", &PSContext::set_equip_crl_path, "Set root second crl path.")
.def("set_replay_attack_time_diff", &PSContext::set_replay_attack_time_diff, "Set replay attack time diff.")
.def("equip_crl_path", &PSContext::equip_crl_path, "Get root second crl path.")
.def("replay_attack_time_diff", &PSContext::replay_attack_time_diff, "Get replay attack time diff.")
.def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.") .def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.")
.def("enable_ssl", &PSContext::enable_ssl, "Get PS SSL mode enabled or disabled.") .def("enable_ssl", &PSContext::enable_ssl, "Get PS SSL mode enabled or disabled.")
.def("set_client_password", &PSContext::set_client_password, "Set the client password to decode the p12 file.") .def("set_client_password", &PSContext::set_client_password, "Set the client password to decode the p12 file.")
@ -441,7 +464,9 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip, .def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
"Set dp norm clip for federated learning secure aggregation.") "Set dp norm clip for federated learning secure aggregation.")
.def("set_encrypt_type", &PSContext::set_encrypt_type, .def("set_encrypt_type", &PSContext::set_encrypt_type,
"Set encrypt type for federated learning secure aggregation."); "Set encrypt type for federated learning secure aggregation.")
.def("set_http_url_prefix", &PSContext::set_http_url_prefix, "Set http url prefix for http communication.")
.def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.");
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data."); (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data."); (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");

View File

@ -76,28 +76,15 @@ constexpr int64_t kPullCmd = 51;
constexpr size_t kInvalidKey = UINT64_MAX; constexpr size_t kInvalidKey = UINT64_MAX;
constexpr int64_t kInvalidID = -1; constexpr int64_t kInvalidID = -1;
constexpr int64_t kGradIndex = 0;
constexpr int64_t kIndiceIndex = 1;
constexpr int64_t kFirstDimSize = 2;
constexpr int64_t kOutDimSize = 3;
constexpr int64_t kBase = 10;
constexpr float kStdDev = 0.01;
constexpr int64_t kSparseLazyAdamIndex = 2;
constexpr int64_t kSparseFtrlIndex = 3;
constexpr int64_t kSparseGradIndex = 6;
constexpr int64_t kSparseIndiceIndex = 7;
constexpr int64_t kHeartbeatTimes = 2;
constexpr int64_t kGradValue = -100;
constexpr uint32_t kMaxMessageSize = static_cast<uint32_t>(100 * (uint32_t(1) << 20)); constexpr uint32_t kMaxMessageSize = static_cast<uint32_t>(100 * (uint32_t(1) << 20));
constexpr char kServerNum[] = "server_num"; constexpr char kServerNum[] = "server_num";
constexpr char kWorkerNum[] = "worker_num"; constexpr char kWorkerNum[] = "worker_num";
constexpr char kNodesIds[] = "node_ids"; constexpr char kNodesIds[] = "node_ids";
constexpr char kNodeId[] = "node_id"; constexpr char kNodeId[] = "node_id";
constexpr char kSuccessCode[] = "0";
constexpr char kErrorCode[] = "1";
constexpr int64_t kSubmitTaskIntervalInMs = 1; constexpr int64_t kSubmitTaskIntervalInMs = 1;
constexpr int64_t kMaxTaskNum = 10240; constexpr int64_t kMaxTaskNum = 10240;
constexpr int64_t kSubmitTimeOutInMs = 30000; constexpr int64_t kSubmitTimeOutInMs = 30000;
@ -105,7 +92,13 @@ constexpr int64_t kRetryCount = 60;
constexpr int64_t kRetryIntervalInMs = 10; constexpr int64_t kRetryIntervalInMs = 10;
constexpr int64_t kThreadNum = 32; constexpr int64_t kThreadNum = 32;
constexpr int64_t kGradIndex = 0;
constexpr int64_t kIndiceIndex = 1;
constexpr int64_t kFirstDimSize = 2;
constexpr int64_t kOutDimSize = 3;
constexpr int64_t kBase = 10;
constexpr float kStdDev = 0.01;
// The timeout period for the scale in node to send the finish message to scheduler. // The timeout period for the scale in node to send the finish message to scheduler.
constexpr uint32_t kScaleInTimeoutInSenconds = 30; constexpr uint32_t kScaleInTimeoutInSenconds = 30;
// The number of retries to determine whether all nodes are successfully registered. // The number of retries to determine whether all nodes are successfully registered.
@ -113,10 +106,26 @@ constexpr uint32_t kCheckRegisteredRetryCount = 30;
// The timeout interval for judging whether all nodes are successfully registered. // The timeout interval for judging whether all nodes are successfully registered.
constexpr uint32_t kCheckRegisteredIntervalInMs = 1000; constexpr uint32_t kCheckRegisteredIntervalInMs = 1000;
// The barrier function which should be called before doing scaling out/in operations.
// It's easy for us to scale out/in nodes after one iteration is completed and keep consistent.
using BarrierBeforeScaleOut = std::function<void(void)>;
using BarrierBeforeScaleIn = std::function<void(void)>;
constexpr int64_t kSparseLazyAdamIndex = 2;
constexpr int64_t kSparseFtrlIndex = 3;
constexpr int64_t kSparseGradIndex = 6;
constexpr int64_t kSparseIndiceIndex = 7;
constexpr int64_t kHeartbeatTimes = 2;
constexpr int64_t kGradValue = -100;
// Whether to support recovery.
constexpr char kIsRecovery[] = "is_recovery";
// The type of persistent storage, currently only supports file storage. // The type of persistent storage, currently only supports file storage.
constexpr char kStoreType[] = "storage_type"; constexpr char kStoreType[] = "storage_type";
// The file used to storage metadata. // The file used to storage metadata.
constexpr char kStoreFilePath[] = "storage_file_path"; constexpr char kStoreFilePath[] = "storage_file_path";
// The file used to storage scheduler metadata.
constexpr char kSchedulerStoreFilePath[] = "scheduler_storage_file_path";
// 1 indicates that the persistent storage type is file. // 1 indicates that the persistent storage type is file.
constexpr char kFileStorage[] = "1"; constexpr char kFileStorage[] = "1";
// The recovery key of json_config. // The recovery key of json_config.
@ -125,6 +134,10 @@ constexpr char kRecoveryWorkerNum[] = "worker_num";
constexpr char kRecoveryServerNum[] = "server_num"; constexpr char kRecoveryServerNum[] = "server_num";
constexpr char kRecoverySchedulerIp[] = "scheduler_ip"; constexpr char kRecoverySchedulerIp[] = "scheduler_ip";
constexpr char kRecoverySchedulerPort[] = "scheduler_port"; constexpr char kRecoverySchedulerPort[] = "scheduler_port";
constexpr char kRecoveryTotalNodeNum[] = "total_node_num";
constexpr char kRecoveryNextWorkerRankId[] = "next_worker_rank_id";
constexpr char kRecoveryNextServerRankId[] = "next_server_rank_id";
constexpr char kRecoveryRegisteredNodesInfos[] = "node_ids";
constexpr char kServerCertPath[] = "server_cert_path"; constexpr char kServerCertPath[] = "server_cert_path";
constexpr char kServerPassword[] = "server_password"; constexpr char kServerPassword[] = "server_password";
@ -132,7 +145,6 @@ constexpr char kCrlPath[] = "crl_path";
constexpr char kClientCertPath[] = "client_cert_path"; constexpr char kClientCertPath[] = "client_cert_path";
constexpr char kClientPassword[] = "client_password"; constexpr char kClientPassword[] = "client_password";
constexpr char kCaCertPath[] = "ca_cert_path"; constexpr char kCaCertPath[] = "ca_cert_path";
constexpr char kCipherList[] = "cipher_list"; constexpr char kCipherList[] = "cipher_list";
constexpr char kCertCheckInterval[] = "cert_check_interval_in_hour"; constexpr char kCertCheckInterval[] = "cert_check_interval_in_hour";
// 7 * 24 // 7 * 24
@ -154,6 +166,7 @@ constexpr int64_t kMaxWarningTime = 180;
constexpr int64_t kLength = 100; constexpr int64_t kLength = 100;
constexpr int64_t kMaxPort = 65535; constexpr int64_t kMaxPort = 65535;
constexpr int64_t kSecurityLevel = 3;
constexpr char kTcpCommunicator[] = "TCP"; constexpr char kTcpCommunicator[] = "TCP";
constexpr char kHttpCommunicator[] = "HTTP"; constexpr char kHttpCommunicator[] = "HTTP";
@ -162,35 +175,14 @@ constexpr char kServerCert[] = "server.p12";
constexpr char kClientCert[] = "client.p12"; constexpr char kClientCert[] = "client.p12";
constexpr char kCaCert[] = "ca.crt"; constexpr char kCaCert[] = "ca.crt";
constexpr char kColon = ':'; constexpr char kColon = ':';
const std::map<std::string, size_t> kCiphers = {{"ECDHE-RSA-AES128-GCM-SHA256", 0}, const std::map<std::string, size_t> kCiphers = {
{"ECDHE-ECDSA-AES128-GCM-SHA256", 1}, {"ECDHE-RSA-AES128-GCM-SHA256", 0}, {"ECDHE-ECDSA-AES128-GCM-SHA256", 1}, {"ECDHE-RSA-AES256-GCM-SHA384", 2},
{"ECDHE-RSA-AES256-GCM-SHA384", 2}, {"ECDHE-ECDSA-AES256-GCM-SHA384", 3}, {"DHE-RSA-AES128-GCM-SHA256", 4}, {"DHE-DSS-AES128-GCM-SHA256", 5},
{"ECDHE-ECDSA-AES256-GCM-SHA384", 3}, {"DHE-RSA-AES256-GCM-SHA384", 6}, {"DHE-DSS-AES256-GCM-SHA384", 7}, {"DHE-PSK-AES128-GCM-SHA256", 8},
{"DHE-RSA-AES128-GCM-SHA256", 4}, {"DHE-PSK-AES256-GCM-SHA384", 9}, {"DHE-PSK-CHACHA20-POLY1305", 10}, {"ECDHE-RSA-CHACHA20-POLY1305", 11},
{"DHE-DSS-AES128-GCM-SHA256", 5}, {"ECDHE-PSK-CHACHA20-POLY1305", 12}, {"DHE-RSA-AES128-CCM", 13}, {"DHE-RSA-AES256-CCM", 14},
{"ECDHE-RSA-AES128-SHA256", 6}, {"DHE-RSA-CHACHA20-POLY1305", 15}, {"DHE-PSK-AES128-CCM", 16}, {"DHE-PSK-AES256-CCM", 17},
{"ECDHE-ECDSA-AES128-SHA256", 7}, {"ECDHE-ECDSA-AES128-CCM", 18}, {"ECDHE-ECDSA-AES256-CCM", 19}, {"ECDHE-ECDSA-CHACHA20-POLY1305", 20}};
{"ECDHE-RSA-AES128-SHA", 8},
{"ECDHE-ECDSA-AES128-SHA", 9},
{"ECDHE-RSA-AES256-SHA384", 10},
{"ECDHE-ECDSA-AES256-SHA384", 11},
{"ECDHE-RSA-AES256-SHA", 12},
{"ECDHE-ECDSA-AES256-SHA", 13},
{"DHE-RSA-AES128-SHA256", 14},
{"DHE-RSA-AES128-SHA", 15},
{"DHE-DSS-AES128-SHA256", 16},
{"DHE-RSA-AES256-SHA256", 17},
{"DHE-DSS-AES256-SHA", 18},
{"DHE-RSA-AES256-SHA", 19},
{"!aNULL", 20},
{"!eNULL", 21},
{"!EXPORT", 22},
{"!DES", 23},
{"!RC4", 24},
{"!3DES", 25},
{"!MD5", 26},
{"!PSK", 27},
{"kEDH+AESGCM", 28}};
#ifdef __APPLE__ #ifdef __APPLE__
using DataPtr = std::shared_ptr<unsigned char>; using DataPtr = std::shared_ptr<unsigned char>;
@ -262,7 +254,7 @@ using HandlerAfterScaleIn = std::function<void(void)>;
constexpr char kClusterSafeMode[] = "The cluster is in safemode."; constexpr char kClusterSafeMode[] = "The cluster is in safemode.";
constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished."; constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished.";
enum class CustomEvent { kIterationRunning = 0, kIterationCompleted }; enum class UserDefineEvent { kIterationRunning = 0, kIterationCompleted, kNodeTimeout };
#define EXC_IF_VEC_IDX_OOB(vec, idx) \ #define EXC_IF_VEC_IDX_OOB(vec, idx) \
{ \ { \

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -50,9 +50,18 @@ void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta,
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
RegisterRespMessage register_resp_message; RegisterRespMessage register_resp_message;
CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size))); CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size)));
MS_LOG(INFO) << "The node id get from scheduler is:" << register_resp_message.node_id()
<< ", rank_id is:" << register_resp_message.rank_id();
if (register_resp_message.node_id() != node_info_.node_id_) { if (register_resp_message.node_id() != node_info_.node_id_) {
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() MS_LOG(ERROR) << "The node id received:" << register_resp_message.node_id()
<< " is not match the current node id:" << node_info_.node_id_; << " is not match the current node id:" << node_info_.node_id_;
return;
}
node_info_.rank_id_ = register_resp_message.rank_id();
if (node_info_.rank_id_ == UINT32_MAX) {
MS_LOG(ERROR) << "The rank id received:" << register_resp_message.rank_id();
return;
} }
// Receive the Register message, indicating that the scheduler is alive, so update the time point at which the // Receive the Register message, indicating that the scheduler is alive, so update the time point at which the
@ -70,7 +79,7 @@ bool AbstractNode::Broadcast(const NodeRole &node_role, const DataPtr &message,
} }
uint32_t broadcast_size = 0; uint32_t broadcast_size = 0;
std::for_each(nodes_address_.begin(), nodes_address_.end(), [&broadcast_size, &node_role](const auto &addr) { (void)std::for_each(nodes_address_.begin(), nodes_address_.end(), [&broadcast_size, &node_role](const auto &addr) {
if (addr.first.first == node_role) { if (addr.first.first == node_role) {
++broadcast_size; ++broadcast_size;
} }
@ -160,19 +169,24 @@ void AbstractNode::BroadcastEvent(const uint32_t &event) {
MS_EXCEPTION_IF_NULL(message_meta); MS_EXCEPTION_IF_NULL(message_meta);
message_meta->set_cmd(NodeCommand::SEND_EVENT); message_meta->set_cmd(NodeCommand::SEND_EVENT);
EventMessage event_message; EventRespMessage event_resp_message;
event_message.set_event(event); event_resp_message.set_event(event);
event_message.set_node_id(node_info_.node_id_);
if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF, event_message.SerializeAsString().data(), for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
event_message.ByteSizeLong())) { const uint32_t rank_id = (*it).first.second;
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) const NodeRole role = (*it).first.first;
<< " the node id:" << node_info_.node_id_ << " send event timeout!"; auto client = GetOrCreateTcpClient(rank_id, role);
return; if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(),
event_resp_message.ByteSizeLong())) {
MS_LOG(ERROR) << "send event to node role:" << CommUtil::NodeRoleToString(role) << ", rank id:" << rank_id
<< " timeout!";
} else {
MS_LOG(INFO) << "send event to node role:" << CommUtil::NodeRoleToString(role) << ", rank id:" << rank_id
<< " successful!";
}
} }
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< " the node id:" << node_info_.node_id_ << "is send event to scheduler!"; << " the node id:" << node_info_.node_id_ << "is send event to server/worker!";
} }
void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) { void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) {
@ -185,10 +199,6 @@ void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const Even
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
int command, const uint32_t &timeout) { int command, const uint32_t &timeout) {
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
return false;
}
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) { if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_ MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
@ -209,11 +219,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, cons
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command, const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command,
const uint32_t &timeout) { const uint32_t &timeout) {
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
return false;
}
uint64_t request_id = AddMessageTrack(data.size()); uint64_t request_id = AddMessageTrack(data.size());
if (rank_ids.size() != data.size() || rank_ids.size() != lens.size()) { if (rank_ids.size() != data.size() || rank_ids.size() != lens.size()) {
@ -248,10 +253,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
int command, VectorPtr *output, const uint32_t &timeout) { int command, VectorPtr *output, const uint32_t &timeout) {
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
return false;
}
MS_EXCEPTION_IF_NULL(message); MS_EXCEPTION_IF_NULL(message);
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) { if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
@ -289,10 +290,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, cons
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command, const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command,
std::vector<VectorPtr> *output, const uint32_t &timeout) { std::vector<VectorPtr> *output, const uint32_t &timeout) {
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
return false;
}
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
uint64_t request_id = AddMessageTrack(data.size()); uint64_t request_id = AddMessageTrack(data.size());
@ -493,10 +490,6 @@ std::shared_ptr<CommunicatorBase> AbstractNode::GetOrCreateTcpComm(const std::st
if (!communicators_.count(kTcpCommunicator)) { if (!communicators_.count(kTcpCommunicator)) {
MS_LOG(INFO) << "Create Tcp communicator."; MS_LOG(INFO) << "Create Tcp communicator.";
auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this); auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this);
PSContext::instance()->cluster_config().scheduler_host = scheduler_ip;
PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port);
PSContext::instance()->cluster_config().initial_worker_num = worker_num;
PSContext::instance()->cluster_config().initial_server_num = server_num;
MS_EXCEPTION_IF_NULL(tcp_comm); MS_EXCEPTION_IF_NULL(tcp_comm);
PSContext::instance()->cluster_config().scheduler_host = scheduler_ip; PSContext::instance()->cluster_config().scheduler_host = scheduler_ip;
PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port); PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port);
@ -521,13 +514,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
<< ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
if (CheckSchedulerTimeout()) { if (CheckSchedulerTimeout()) {
MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(WARNING) << "Scheduler is Timeout, please recovery.";
<< ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!";
is_finish_ = true;
wait_finish_cond_.notify_all();
if (!is_already_stopped_) {
OnEventCallback(ClusterEvent::SCHEDULER_TIMEOUT);
}
} }
} else { } else {
UpdateSchedulerTime(); UpdateSchedulerTime();
@ -549,6 +536,7 @@ bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
HeartbeatMessage heartbeat_message; HeartbeatMessage heartbeat_message;
heartbeat_message.set_node_id(node_info_.node_id_); heartbeat_message.set_node_id(node_info_.node_id_);
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " Send heartbeat!";
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(), if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
heartbeat_message.ByteSizeLong(), kCommTimeoutInSeconds)) { heartbeat_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
@ -580,20 +568,31 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
HeartbeatRespMessage heartbeat_resp_message; HeartbeatRespMessage heartbeat_resp_message;
CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size))); CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));
if (heartbeat_resp_message.cluster_state() != current_cluster_state_) {
MS_LOG(INFO) << "cluster change state from:" << CommUtil::ClusterStateToString(current_cluster_state_) << " to "
<< CommUtil::ClusterStateToString(heartbeat_resp_message.cluster_state());
}
current_cluster_state_ = heartbeat_resp_message.cluster_state(); current_cluster_state_ = heartbeat_resp_message.cluster_state();
MS_LOG(DEBUG) << "The current cluster state from heartbeat:" MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
<< CommUtil::ClusterStateToString(current_cluster_state_); << CommUtil::ClusterStateToString(current_cluster_state_);
std::string timeoutNodeId;
all_nodes_info_.clear(); all_nodes_info_.clear();
for (const auto &it : heartbeat_resp_message.servers_meta()) { for (const auto &it : heartbeat_resp_message.servers_meta()) {
NodeInfo info; NodeInfo info;
info.ip_ = it.ip(); info.ip_ = it.ip();
info.node_id_ = it.node_id(); info.node_id_ = it.node_id();
info.port_ = static_cast<uint16_t>(it.port()); info.port_ = it.port();
info.node_role_ = it.role(); info.node_role_ = it.role();
info.rank_id_ = it.rank_id(); info.rank_id_ = it.rank_id();
info.is_alive = it.is_alive(); info.is_alive = it.is_alive();
if (!info.is_alive) {
timeoutNodeId += (info.node_id_ + " ");
}
all_nodes_info_[info.node_id_] = info; all_nodes_info_[info.node_id_] = info;
MS_LOG(DEBUG) << "The node id:" << info.node_id_ << ", the rank id:" << info.rank_id_ MS_LOG(DEBUG) << "The node id:" << info.node_id_ << ", the rank id:" << info.rank_id_
<< ", the node role:" << CommUtil::NodeRoleToString(info.node_role_) << " is alive:" << info.is_alive; << ", the node role:" << CommUtil::NodeRoleToString(info.node_role_) << " is alive:" << info.is_alive;
@ -608,7 +607,8 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
wait_start_cond_.notify_all(); wait_start_cond_.notify_all();
OnEventCallback(ClusterEvent::NODE_TIMEOUT); OnEventCallback(ClusterEvent::NODE_TIMEOUT);
} else { } else {
MS_LOG(INFO) << "The node is support recovery, users can pull up this node to restore the cluster."; MS_LOG(INFO) << "The nodes:" << timeoutNodeId
<< "is support recovery, users can pull up this node to restore the cluster.";
} }
} }
} }
@ -653,14 +653,14 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
return; return;
} }
SendMetadataMessage send_meta_message; SendMetadataMessage send_meta_message;
CHECK_RETURN_TYPE(send_meta_message.ParseFromArray(data, SizeToInt(size))); send_meta_message.ParseFromArray(data, SizeToInt(size));
worker_num_ = send_meta_message.worker_num(); worker_num_ = send_meta_message.worker_num();
server_num_ = send_meta_message.server_num(); server_num_ = send_meta_message.server_num();
if (send_meta_message.rank_id() < 0) { if (send_meta_message.rank_id() < 0) {
MS_LOG(EXCEPTION) << "The rank id is wrong."; MS_LOG(EXCEPTION) << "The rank id is wrong.";
} }
node_info_.rank_id_ = send_meta_message.rank_id(); node_info_.rank_id_ = send_meta_message.rank_id();
current_cluster_state_ = send_meta_message.cluster_state(); UpdateClusterState(send_meta_message.cluster_state());
MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_ MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_
<< ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_) << ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_)
<< ", the rank id:" << node_info_.rank_id_; << ", the rank id:" << node_info_.rank_id_;
@ -669,7 +669,8 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
nodes_address_.clear(); nodes_address_.clear();
for (const auto &it : send_meta_message.servers_meta()) { for (const auto &it : send_meta_message.servers_meta()) {
nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port()); nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port());
MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port() << ", the rank id:" << it.rank_id(); MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(it.role()) << ", node id:" << it.node_id()
<< ", rank id:" << it.rank_id() << ", ip:" << it.ip() << ", port:" << it.port();
} }
client_mutex_.unlock(); client_mutex_.unlock();
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) { if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
@ -690,6 +691,7 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
std::lock_guard<std::mutex> lock(client_mutex_); std::lock_guard<std::mutex> lock(client_mutex_);
connected_nodes_.clear(); connected_nodes_.clear();
PersistMetaData();
} }
void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
@ -710,11 +712,13 @@ void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &con
MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta); MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
MS_LOG(INFO) << "This node receive a scale out done from scheduler.";
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) { if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
MS_LOG(WARNING) << "Server response message failed."; MS_LOG(WARNING) << "Server response message failed.";
} }
is_ready_ = true; is_ready_ = true;
current_cluster_state_ = ClusterState::CLUSTER_READY; UpdateClusterState(ClusterState::CLUSTER_READY);
PersistMetaData();
} }
void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
@ -727,7 +731,8 @@ void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn
MS_LOG(WARNING) << "Server response message failed."; MS_LOG(WARNING) << "Server response message failed.";
} }
is_ready_ = true; is_ready_ = true;
current_cluster_state_ = ClusterState::CLUSTER_READY; UpdateClusterState(ClusterState::CLUSTER_READY);
PersistMetaData();
} }
void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
@ -736,13 +741,18 @@ void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, cons
MS_EXCEPTION_IF_NULL(meta); MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
EventRespMessage event_resp_message; EventRespMessage event_resp_message;
CHECK_RETURN_TYPE(event_resp_message.ParseFromArray(data, SizeToInt(size))); event_resp_message.ParseFromArray(data, SizeToInt(size));
uint32_t event = event_resp_message.event(); uint32_t event = event_resp_message.event();
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) { if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
MS_LOG(WARNING) << "Server response message failed."; MS_LOG(WARNING) << "Server response message failed.";
} }
MS_LOG(INFO) << "This node receive a event:" << event;
if (event == static_cast<uint32_t>(ps::UserDefineEvent::kNodeTimeout)) {
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
} else {
OnCustomEventCallback(event); OnCustomEventCallback(event);
} }
}
void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &, const void *data, size_t size) { const Protos &, const void *data, size_t size) {
@ -751,7 +761,7 @@ void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, c
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
ScaleOutMessage scale_out_message; ScaleOutMessage scale_out_message;
CHECK_RETURN_TYPE(scale_out_message.ParseFromArray(data, SizeToInt(size))); scale_out_message.ParseFromArray(data, SizeToInt(size));
int32_t worker_num = scale_out_message.worker_num(); int32_t worker_num = scale_out_message.worker_num();
int32_t server_num = scale_out_message.server_num(); int32_t server_num = scale_out_message.server_num();
MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num; MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num;
@ -760,7 +770,7 @@ void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, c
MS_LOG(WARNING) << "Server response message failed."; MS_LOG(WARNING) << "Server response message failed.";
} }
OnEventCallback(ClusterEvent::READY_FOR_SCALE_OUT); OnEventCallback(ClusterEvent::READY_FOR_SCALE_OUT);
current_cluster_state_ = ClusterState::CLUSTER_SCALE_OUT; UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
is_ready_ = false; is_ready_ = false;
} }
@ -771,7 +781,7 @@ void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, co
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
ScaleInMessage scale_in_message; ScaleInMessage scale_in_message;
CHECK_RETURN_TYPE(scale_in_message.ParseFromArray(data, SizeToInt(size))); scale_in_message.ParseFromArray(data, SizeToInt(size));
int32_t worker_num = scale_in_message.worker_num(); int32_t worker_num = scale_in_message.worker_num();
int32_t server_num = scale_in_message.server_num(); int32_t server_num = scale_in_message.server_num();
MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num; MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num;
@ -789,7 +799,44 @@ void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, co
MS_LOG(WARNING) << "Server response message failed."; MS_LOG(WARNING) << "Server response message failed.";
} }
OnEventCallback(ClusterEvent::READY_FOR_SCALE_IN); OnEventCallback(ClusterEvent::READY_FOR_SCALE_IN);
current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN; UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
is_ready_ = false;
}
void AbstractNode::ProcessSchedulerRecovery(const std::shared_ptr<TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
size_t size) {
MS_EXCEPTION_IF_NULL(conn);
MS_EXCEPTION_IF_NULL(meta);
MS_EXCEPTION_IF_NULL(data);
if (is_connected_to_scheduler_.load()) {
MS_LOG(WARNING) << "This node has been connected to scheduler.";
return;
}
SendMetadataMessage scheduler_recovery_message;
(void)scheduler_recovery_message.ParseFromArray(data, SizeToInt(size));
worker_num_ = scheduler_recovery_message.worker_num();
server_num_ = scheduler_recovery_message.server_num();
uint32_t rank_id = scheduler_recovery_message.rank_id();
MS_LOG(INFO) << "[Scheduler Recovery]: The scheduler recovery worker num:" << worker_num_
<< ", the server num:" << server_num_ << ", the rank id: " << rank_id;
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
MS_LOG(WARNING) << "[Scheduler Recovery]: Server response message failed.";
}
MS_LOG(INFO) << "[Scheduler Recovery]: Server response message success!.";
if (!InitClientToScheduler()) {
MS_LOG(WARNING) << "[Scheduler Recovery]: Server node connect to scheduler timedout!";
}
Register(client_to_scheduler_);
std::lock_guard<std::mutex> lock(client_mutex_);
connected_nodes_.clear();
MS_LOG(INFO) << "[Scheduler Recovery]: This node connect to scheduler successful!";
UpdateClusterState(ClusterState::CLUSTER_SCHEDULER_RECOVERY);
is_ready_ = false; is_ready_ = false;
} }
@ -819,6 +866,14 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
return res; return res;
} }
void AbstractNode::InitClientToServer() {
// create tcp client to myself in case of event dispatch failed when Send msg to server 0 failed
client_to_server_ = std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, config_.get());
MS_EXCEPTION_IF_NULL(client_to_server_);
client_to_server_->Init();
MS_LOG(INFO) << "The node start a tcp client to this node!";
}
bool AbstractNode::InitClientToScheduler() { bool AbstractNode::InitClientToScheduler() {
if (config_ == nullptr) { if (config_ == nullptr) {
MS_LOG(WARNING) << "The config is empty."; MS_LOG(WARNING) << "The config is empty.";
@ -843,7 +898,6 @@ bool AbstractNode::InitClientToScheduler() {
MsException::Instance().SetException(); MsException::Instance().SetException();
} }
}); });
client_to_scheduler_->Init(); client_to_scheduler_->Init();
client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() { client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
MS_LOG(INFO) << "The node start a tcp client!"; MS_LOG(INFO) << "The node start a tcp client!";
@ -851,11 +905,14 @@ bool AbstractNode::InitClientToScheduler() {
}); });
client_to_scheduler_thread_->detach(); client_to_scheduler_thread_->detach();
client_to_scheduler_->set_connected_callback([&]() { is_connected_to_scheduler_ = true; });
client_to_scheduler_->set_disconnected_callback([&]() { client_to_scheduler_->set_disconnected_callback([&]() {
std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval)); std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval));
if (is_ready_.load() == false) { if (is_ready_.load() == false) {
client_to_scheduler_->Init(); client_to_scheduler_->Init();
} }
is_connected_to_scheduler_ = false;
}); });
bool wait_res = client_to_scheduler_->WaitConnected(); bool wait_res = client_to_scheduler_->WaitConnected();
if (!wait_res) { if (!wait_res) {
@ -892,6 +949,9 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
case NodeCommand::COLLECTIVE_SEND_DATA: case NodeCommand::COLLECTIVE_SEND_DATA:
MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!"; MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
break; break;
case NodeCommand::SEND_EVENT:
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a send_event command message response!";
break;
default: default:
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
} }
@ -964,8 +1024,9 @@ void AbstractNode::ProcessSendData(const std::shared_ptr<TcpConnection> &conn, c
if (size > 0) { if (size > 0) {
size_t dest_size = size; size_t dest_size = size;
size_t src_size = size; size_t src_size = size;
if (memcpy_s(res.get(), dest_size, data, src_size) != EOK) { auto ret = memcpy_s(res.get(), dest_size, data, src_size);
MS_LOG(EXCEPTION) << "The memcpy_s error"; if (ret != EOK) {
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
} }
} }
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
@ -1066,6 +1127,7 @@ void AbstractNode::InitServerHandler() {
server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone; server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone; server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent; server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent;
server_handler_[NodeCommand::SCHEDULER_RECOVERY] = &AbstractNode::ProcessSchedulerRecovery;
} }
void AbstractNode::InitNodeInfo(const NodeRole &role) { void AbstractNode::InitNodeInfo(const NodeRole &role) {
@ -1090,8 +1152,8 @@ void AbstractNode::InitNodeInfo(const NodeRole &role) {
} }
void AbstractNode::InitNodeNum() { void AbstractNode::InitNodeNum() {
worker_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_worker_num); worker_num_ = UintToInt(PSContext::instance()->cluster_config().initial_worker_num);
server_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_server_num); server_num_ = UintToInt(PSContext::instance()->cluster_config().initial_server_num);
scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host; scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host;
scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port; scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port;
MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_ MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_
@ -1104,7 +1166,10 @@ bool AbstractNode::Recover() {
MS_LOG(INFO) << "The node is support recovery."; MS_LOG(INFO) << "The node is support recovery.";
node_recovery_ = std::make_unique<NodeRecovery>(this); node_recovery_ = std::make_unique<NodeRecovery>(this);
MS_EXCEPTION_IF_NULL(node_recovery_); MS_EXCEPTION_IF_NULL(node_recovery_);
node_recovery_->Initialize(config_->Get(kKeyRecovery, "")); if (!node_recovery_->Initialize(config_->Get(kKeyRecovery, ""))) {
MS_LOG(ERROR) << "Initializing node recovery failed.";
return false;
}
return node_recovery_->Recover(); return node_recovery_->Recover();
} }
return false; return false;
@ -1132,7 +1197,7 @@ void AbstractNode::OnCustomEventCallback(const uint32_t &event) {
} }
} }
bool AbstractNode::IsWorkerOrServer0(const mindspore::HashMap<std::string, NodeInfo> &info) { bool AbstractNode::IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info) {
for (const auto &it : info) { for (const auto &it : info) {
if (it.second.is_alive == true && it.second.node_role_ == NodeRole::WORKER) { if (it.second.is_alive == true && it.second.node_role_ == NodeRole::WORKER) {
return true; return true;
@ -1149,7 +1214,12 @@ void AbstractNode::CreateTcpServer() {
MS_EXCEPTION_IF_NULL(config_); MS_EXCEPTION_IF_NULL(config_);
std::string interface; std::string interface;
std::string server_ip; std::string server_ip;
if (ps::PSContext::instance()->server_mode().empty()) {
// If the server mode is not set, use 127.0.0.1 as server ip address for distributed learning.
server_ip = "127.0.0.1";
} else {
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
}
server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get()); server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get());
MS_EXCEPTION_IF_NULL(server_); MS_EXCEPTION_IF_NULL(server_);
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
@ -1179,6 +1249,29 @@ void AbstractNode::CreateTcpServer() {
MS_EXCEPTION_IF_NULL(server_thread_); MS_EXCEPTION_IF_NULL(server_thread_);
server_thread_->detach(); server_thread_->detach();
} }
void AbstractNode::UpdateClusterState(const ClusterState &state) {
std::lock_guard<std::mutex> lk(cluster_state_mutex_);
MS_LOG(INFO) << "[state]: Cluster state change from:" << CommUtil::ClusterStateToString(current_cluster_state_)
<< " to " << CommUtil::ClusterStateToString(state);
current_cluster_state_ = state;
}
void AbstractNode::PersistMetaData() {
if (node_recovery_ == nullptr) {
MS_LOG(WARNING) << "node recovery is null, so don't persist meta data";
return;
}
if (config_->Exists(kKeyRecovery)) {
ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
clusterConfig.scheduler_host = this->scheduler_ip();
clusterConfig.scheduler_port = this->scheduler_port();
clusterConfig.initial_worker_num = worker_num_;
clusterConfig.initial_server_num = server_num_;
node_recovery_->Persist(clusterConfig);
}
}
} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -22,8 +22,8 @@
#include <memory> #include <memory>
#include <map> #include <map>
#include <vector> #include <vector>
#include <unordered_map>
#include "utils/hash_map.h"
#include "ps/core/node.h" #include "ps/core/node.h"
#include "ps/core/communicator/message.h" #include "ps/core/communicator/message.h"
#include "ps/core/follower_scaler.h" #include "ps/core/follower_scaler.h"
@ -44,10 +44,12 @@ class AbstractNode : public Node {
: heart_beat_thread_(nullptr), : heart_beat_thread_(nullptr),
client_to_scheduler_thread_(nullptr), client_to_scheduler_thread_(nullptr),
client_to_scheduler_(nullptr), client_to_scheduler_(nullptr),
client_to_server_(nullptr),
server_(nullptr), server_(nullptr),
server_thread_(nullptr), server_thread_(nullptr),
worker_num_(-1), worker_num_(-1),
server_num_(-1), server_num_(-1),
is_connected_to_scheduler_(false),
is_current_node_scale_in_(false), is_current_node_scale_in_(false),
follower_scaler_(nullptr), follower_scaler_(nullptr),
node_recovery_(nullptr), node_recovery_(nullptr),
@ -94,14 +96,14 @@ class AbstractNode : public Node {
void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb); void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb);
bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command, bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
const uint32_t &timeout = kTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data, bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
const std::vector<size_t> &lens, int command, const uint32_t &timeout = kTimeoutInSeconds); const std::vector<size_t> &lens, int command, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command, bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command,
VectorPtr *output, const uint32_t &timeout = kTimeoutInSeconds); VectorPtr *output, const uint32_t &timeout = kCommTimeoutInSeconds);
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data, bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output, const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
const uint32_t &timeout = kTimeoutInSeconds); const uint32_t &timeout = kCommTimeoutInSeconds);
uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size); uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id, std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
@ -170,6 +172,10 @@ class AbstractNode : public Node {
void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size); const Protos &protos, const void *data, size_t size);
// The worker/server processes the scheduler recovery message from scheduelr
void ProcessSchedulerRecovery(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &, const void *data, size_t size);
// The worker/server processes the SEND_EVENT message from scheduelr // The worker/server processes the SEND_EVENT message from scheduelr
void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta, void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size); const Protos &protos, const void *data, size_t size);
@ -180,6 +186,7 @@ class AbstractNode : public Node {
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout); bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
bool WaitForDisconnect(const uint32_t &timeout); bool WaitForDisconnect(const uint32_t &timeout);
bool InitClientToScheduler(); bool InitClientToScheduler();
void InitClientToServer();
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id, const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id,
const NodeRole &role = NodeRole::SERVER); const NodeRole &role = NodeRole::SERVER);
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message, bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
@ -212,14 +219,18 @@ class AbstractNode : public Node {
// Trigger the callback corresponding to the custom event. // Trigger the callback corresponding to the custom event.
void OnCustomEventCallback(const uint32_t &event); void OnCustomEventCallback(const uint32_t &event);
bool IsWorkerOrServer0(const mindspore::HashMap<std::string, NodeInfo> &info); bool IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info);
void CreateTcpServer(); void CreateTcpServer();
void UpdateClusterState(const ClusterState &state);
void PersistMetaData();
std::unique_ptr<std::thread> heart_beat_thread_; std::unique_ptr<std::thread> heart_beat_thread_;
std::unique_ptr<std::thread> client_to_scheduler_thread_; std::unique_ptr<std::thread> client_to_scheduler_thread_;
std::shared_ptr<TcpClient> client_to_scheduler_; std::shared_ptr<TcpClient> client_to_scheduler_;
std::shared_ptr<TcpClient> client_to_server_;
// the key is: <node_role,rank_id>, the value is: <ip, port> // the key is: <node_role,rank_id>, the value is: <ip, port>
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_; std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
// the map's key is: rank_id // the map's key is: rank_id
@ -233,13 +244,13 @@ class AbstractNode : public Node {
std::condition_variable receive_cond_; std::condition_variable receive_cond_;
// the key is rank_id, the value is rank_id's expected request_id // the key is rank_id, the value is rank_id's expected request_id
mindspore::HashMap<uint32_t, uint64_t> expected_rank_request_ids_; std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_;
// the key is rank_id, the value is rank_id's actual request_id // the key is rank_id, the value is rank_id's actual request_id
mindspore::HashMap<uint32_t, uint64_t> actual_rank_request_ids_; std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
std::mutex rank_request_ids_mutex; std::mutex rank_request_ids_mutex;
timeval scheduler_time_{0, 0}; timeval scheduler_time_{0, 0};
mindspore::HashMap<NodeCommand, ResponseHandler> handlers_; std::unordered_map<NodeCommand, ResponseHandler> handlers_;
mindspore::HashMap<NodeCommand, ServerHandler> server_handler_; std::unordered_map<NodeCommand, ServerHandler> server_handler_;
// Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA // Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA
std::shared_ptr<TcpServer> server_; std::shared_ptr<TcpServer> server_;
@ -247,7 +258,7 @@ class AbstractNode : public Node {
int32_t worker_num_; int32_t worker_num_;
int32_t server_num_; int32_t server_num_;
std::atomic<bool> is_connected_to_scheduler_;
// Identify whether the current node is a scale in node. // Identify whether the current node is a scale in node.
std::atomic<bool> is_current_node_scale_in_; std::atomic<bool> is_current_node_scale_in_;
@ -273,11 +284,12 @@ class AbstractNode : public Node {
uint16_t scheduler_port_; uint16_t scheduler_port_;
// Synchronize all node metadata from the scheduler. // Synchronize all node metadata from the scheduler.
mindspore::HashMap<std::string, NodeInfo> all_nodes_info_; std::unordered_map<std::string, NodeInfo> all_nodes_info_;
RequestHandler request_handler_; RequestHandler request_handler_;
mindspore::HashMap<std::string, std::shared_ptr<CommunicatorBase>> communicators_; std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_;
std::mutex communicator_mutex_; std::mutex communicator_mutex_;
std::mutex cluster_state_mutex_;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps

View File

@ -21,8 +21,10 @@
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <unordered_map>
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "ps/core/node_info.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
@ -38,10 +40,12 @@ struct ClusterConfig {
scheduler_host(host), scheduler_host(host),
scheduler_port(port), scheduler_port(port),
heartbeat_timeout(30), heartbeat_timeout(30),
cluster_available_timeout(300), cluster_available_timeout(900),
connect_interval(3000), connect_interval(3000),
scheduler_timeout(30) {} scheduler_timeout(30),
initial_total_node_num(0),
initial_next_worker_rank_id(0),
initial_next_server_rank_id(0) {}
// Configure through environment variables:MS_WORKER_NUM // Configure through environment variables:MS_WORKER_NUM
uint32_t initial_worker_num; uint32_t initial_worker_num;
// Configure through environment variables:MS_SERVER_NUM // Configure through environment variables:MS_SERVER_NUM
@ -59,6 +63,11 @@ struct ClusterConfig {
uint32_t connect_interval; uint32_t connect_interval;
// When the scheduler exits, the worker and server can continue to work for 5 hours // When the scheduler exits, the worker and server can continue to work for 5 hours
int64_t scheduler_timeout; int64_t scheduler_timeout;
// the node that has bean registered to scheduler
std::unordered_map<std::string, NodeInfo> initial_registered_nodes_infos;
uint32_t initial_total_node_num;
uint32_t initial_next_worker_rank_id;
uint32_t initial_next_server_rank_id;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps

View File

@ -32,7 +32,6 @@ namespace core {
*/ */
struct ClusterMetadata { struct ClusterMetadata {
ClusterMetadata(const uint32_t &worker, const uint32_t &server) : worker_num(worker), server_num(server) {} ClusterMetadata(const uint32_t &worker, const uint32_t &server) : worker_num(worker), server_num(server) {}
uint32_t worker_num; uint32_t worker_num;
uint32_t server_num; uint32_t server_num;
}; };

View File

@ -96,6 +96,28 @@ void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *i
freeifaddrs(if_address); freeifaddrs(if_address);
} }
std::string CommUtil::GetLoopBackInterfaceName() {
struct ifaddrs *if_address = nullptr;
struct ifaddrs *ifa = nullptr;
if (getifaddrs(&if_address) == -1) {
MS_LOG(WARNING) << "Get ifaddrs failed.";
}
for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == nullptr) {
continue;
}
if (ifa->ifa_flags & IFF_LOOPBACK) {
MS_LOG(INFO) << "Loop back interface name is " << ifa->ifa_name;
return ifa->ifa_name;
}
}
MS_EXCEPTION_IF_NULL(if_address);
freeifaddrs(if_address);
return "";
}
std::string CommUtil::GenerateUUID() { std::string CommUtil::GenerateUUID() {
std::stringstream ss; std::stringstream ss;
int i; int i;
@ -135,6 +157,18 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!"; MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
} }
} }
NodeRole CommUtil::StringToNodeRole(const std::string &roleStr) {
if (roleStr == "SCHEDULER") {
return NodeRole::SCHEDULER;
} else if (roleStr == "SERVER") {
return NodeRole::SERVER;
} else if (roleStr == "WORKER") {
return NodeRole::WORKER;
} else {
MS_LOG(EXCEPTION) << "The node role string:" << roleStr << " is illegal!";
}
}
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num, bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
const int32_t &total_server_num) { const int32_t &total_server_num) {
if (node_role == NodeRole::SERVER && (rank_id > IntToUint(total_server_num) - 1)) { if (node_role == NodeRole::SERVER && (rank_id > IntToUint(total_server_num) - 1)) {
@ -183,7 +217,7 @@ bool CommUtil::IsFileExists(const std::string &file) {
} }
std::string CommUtil::ClusterStateToString(const ClusterState &state) { std::string CommUtil::ClusterStateToString(const ClusterState &state) {
MS_LOG(INFO) << "The cluster state:" << state; MS_LOG(DEBUG) << "The cluster state:" << state;
if (state < SizeToInt(kClusterState.size())) { if (state < SizeToInt(kClusterState.size())) {
return kClusterState.at(state); return kClusterState.at(state);
} else { } else {
@ -191,21 +225,50 @@ std::string CommUtil::ClusterStateToString(const ClusterState &state) {
} }
} }
std::string CommUtil::ParseConfig(const Configuration &config, const std::string &data) { std::string CommUtil::ParseConfig(const Configuration &config, const std::string &key) {
if (!config.IsInitialized()) { if (!config.IsInitialized()) {
MS_LOG(INFO) << "The config is not initialized."; MS_LOG(INFO) << "The config is not initialized.";
return ""; return "";
} }
if (!const_cast<Configuration &>(config).Exists(data)) { if (!const_cast<Configuration &>(config).Exists(key)) {
MS_LOG(INFO) << "The data:" << data << " is not exist."; MS_LOG(INFO) << "The key:" << key << " is not exist.";
return ""; return "";
} }
std::string path = config.GetString(data, ""); std::string path = config.GetString(key, "");
return path; return path;
} }
bool CommUtil::verifyCertTimeStamp(const X509 *cert) {
ASN1_TIME *start = X509_getm_notBefore(cert);
ASN1_TIME *end = X509_getm_notAfter(cert);
int day = 0;
int sec = 0;
int ret = ASN1_TIME_diff(&day, &sec, start, NULL);
if (ret != 1) {
return false;
}
if (day < 0 || sec < 0) {
MS_LOG(ERROR) << "cert start time is later than now time.";
return false;
}
day = 0;
sec = 0;
ret = ASN1_TIME_diff(&day, &sec, NULL, end);
if (ret != 1) {
return false;
}
if (day < 0 || sec < 0) {
MS_LOG(ERROR) << "cert end time is sooner than now time.";
return false;
}
return true;
}
bool CommUtil::VerifyCertTime(const X509 *cert, int64_t time) { bool CommUtil::VerifyCertTime(const X509 *cert, int64_t time) {
MS_EXCEPTION_IF_NULL(cert); MS_EXCEPTION_IF_NULL(cert);
ASN1_TIME *start = X509_getm_notBefore(cert); ASN1_TIME *start = X509_getm_notBefore(cert);
@ -249,61 +312,55 @@ bool CommUtil::VerifyCRL(const X509 *cert, const std::string &crl_path) {
MS_ERROR_IF_NULL_W_RET_VAL(cert, false); MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
BIO *bio = BIO_new_file(crl_path.c_str(), "r"); BIO *bio = BIO_new_file(crl_path.c_str(), "r");
MS_ERROR_IF_NULL_W_RET_VAL(bio, false); MS_ERROR_IF_NULL_W_RET_VAL(bio, false);
X509_CRL *root_crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
X509_CRL *root_crl = nullptr;
EVP_PKEY *evp_pkey = nullptr;
bool result = true;
do {
root_crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
MS_ERROR_IF_NULL_W_RET_VAL(root_crl, false); MS_ERROR_IF_NULL_W_RET_VAL(root_crl, false);
EVP_PKEY *evp_pkey = X509_get_pubkey(const_cast<X509 *>(cert)); evp_pkey = X509_get_pubkey(const_cast<X509 *>(cert));
MS_ERROR_IF_NULL_W_RET_VAL(evp_pkey, false); MS_ERROR_IF_NULL_W_RET_VAL(evp_pkey, false);
int ret = X509_CRL_verify(root_crl, evp_pkey); int ret = X509_CRL_verify(root_crl, evp_pkey);
BIO_free_all(bio);
if (ret == 1) { if (ret == 1) {
MS_LOG(WARNING) << "Equip cert in root crl, verify failed"; MS_LOG(WARNING) << "Equip cert in root crl, verify failed";
return false; result = false;
break;
} }
} while (0);
BIO_free_all(bio);
EVP_PKEY_free(evp_pkey);
X509_CRL_free(root_crl);
MS_LOG(INFO) << "VerifyCRL success."; MS_LOG(INFO) << "VerifyCRL success.";
return true; return result;
} }
bool CommUtil::VerifyCommonName(const X509 *cert, const std::string &ca_path) { bool CommUtil::VerifyCommonName(const X509 *caCert, const X509 *subCert) {
MS_ERROR_IF_NULL_W_RET_VAL(cert, false); MS_EXCEPTION_IF_NULL(caCert);
X509 *cert_temp = const_cast<X509 *>(cert); MS_EXCEPTION_IF_NULL(subCert);
char subject_cn[256] = ""; char caSubjectCN[256] = "";
char issuer_cn[256] = ""; char subIssuerCN[256] = "";
X509_NAME *subject_name = X509_get_subject_name(cert_temp);
X509_NAME *issuer_name = X509_get_issuer_name(cert_temp);
MS_ERROR_IF_NULL_W_RET_VAL(subject_name, false);
MS_ERROR_IF_NULL_W_RET_VAL(issuer_name, false);
if (!X509_NAME_get_text_by_NID(subject_name, NID_commonName, subject_cn, sizeof(subject_cn))) {
MS_LOG(WARNING) << "Get text by nid failed.";
return false;
}
if (!X509_NAME_get_text_by_NID(issuer_name, NID_commonName, issuer_cn, sizeof(issuer_cn))) {
MS_LOG(WARNING) << "Get text by nid failed.";
return false;
}
MS_LOG(INFO) << "the subject:" << subject_cn << ", the issuer:" << issuer_cn;
BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r"); X509_NAME *caSubjectX509CN = X509_get_subject_name(caCert);
MS_EXCEPTION_IF_NULL(ca_bio); X509_NAME *subIssuerX509CN = X509_get_issuer_name(subCert);
X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
MS_EXCEPTION_IF_NULL(ca_cert); int ret = X509_NAME_get_text_by_NID(caSubjectX509CN, NID_commonName, caSubjectCN, sizeof(caSubjectCN));
char ca_subject_cn[256] = ""; if (ret < 0) {
char ca_issuer_cn[256] = "";
X509_NAME *ca_subject_name = X509_get_subject_name(ca_cert);
X509_NAME *ca_issuer_name = X509_get_issuer_name(ca_cert);
MS_ERROR_IF_NULL_W_RET_VAL(ca_subject_name, false);
MS_ERROR_IF_NULL_W_RET_VAL(ca_issuer_name, false);
if (!X509_NAME_get_text_by_NID(ca_subject_name, NID_commonName, ca_subject_cn, sizeof(subject_cn))) {
MS_LOG(WARNING) << "Get text by nid failed.";
return false; return false;
} }
if (!X509_NAME_get_text_by_NID(ca_issuer_name, NID_commonName, ca_issuer_cn, sizeof(issuer_cn))) { ret = X509_NAME_get_text_by_NID(subIssuerX509CN, NID_commonName, subIssuerCN, sizeof(subIssuerCN));
MS_LOG(WARNING) << "Get text by nid failed."; if (ret < 0) {
return false; return false;
} }
MS_LOG(INFO) << "the subject:" << ca_subject_cn << ", the issuer:" << ca_issuer_cn;
BIO_free_all(ca_bio); std::string caSubjectCNStr = caSubjectCN;
if (strcmp(issuer_cn, ca_subject_cn) != 0) { std::string subIssuerCNStr = subIssuerCN;
if (caSubjectCNStr != subIssuerCNStr) {
MS_LOG(EXCEPTION) << "root CA cert subject cn is not equal with equip CA cert issuer cn.";
return false; return false;
} }
return true; return true;
@ -330,19 +387,170 @@ bool CommUtil::VerifyCipherList(const std::vector<std::string> &list) {
return true; return true;
} }
void CommUtil::InitOpenSSLEnv() { bool CommUtil::verifyCertKeyID(const X509 *caCert, const X509 *subCert) {
if (!SSL_library_init()) { MS_EXCEPTION_IF_NULL(caCert);
MS_LOG(EXCEPTION) << "SSL_library_init failed."; MS_EXCEPTION_IF_NULL(subCert);
int crit = 0;
ASN1_OCTET_STRING *skid =
reinterpret_cast<ASN1_OCTET_STRING *>(X509_get_ext_d2i(caCert, NID_subject_key_identifier, &crit, NULL));
MS_EXCEPTION_IF_NULL(skid);
const int keyidLen = 512;
char subject_keyid[keyidLen] = {0};
for (int i = 0; i < skid->length; i++) {
char keyid[8] = {0};
int base = keyidLen;
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)skid->data[i]);
int ret = strcat_s(subject_keyid, base, keyid);
if (ret == -1) {
return false;
} }
if (!ERR_load_crypto_strings()) {
MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
} }
if (!SSL_load_error_strings()) {
MS_LOG(EXCEPTION) << "SSL_load_error_strings failed."; AUTHORITY_KEYID *akeyid =
reinterpret_cast<AUTHORITY_KEYID *>(X509_get_ext_d2i(subCert, NID_authority_key_identifier, &crit, NULL));
MS_EXCEPTION_IF_NULL(akeyid);
MS_EXCEPTION_IF_NULL(akeyid->keyid);
char issuer_keyid[keyidLen] = {0};
for (int i = 0; i < akeyid->keyid->length; i++) {
char keyid[8] = {0};
int base = keyidLen;
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)(akeyid->keyid->data[i]));
int ret = strcat_s(issuer_keyid, base, keyid);
if (ret == -1) {
return false;
} }
if (!OpenSSL_add_all_algorithms()) {
MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
} }
std::string subject_keyid_str = subject_keyid;
std::string issuer_keyid_str = issuer_keyid;
if (subject_keyid_str != issuer_keyid_str) {
return false;
}
return true;
}
bool CommUtil::verifySingature(const X509 *caCert, const X509 *subCert) {
MS_EXCEPTION_IF_NULL(caCert);
MS_EXCEPTION_IF_NULL(subCert);
EVP_PKEY *caCertPubKey = X509_get_pubkey(const_cast<X509 *>(caCert));
int ret = 0;
ret = X509_verify(const_cast<X509 *>(subCert), caCertPubKey);
if (ret != 1) {
EVP_PKEY_free(caCertPubKey);
MS_LOG(ERROR) << "sub cert verify is failed";
return false;
}
MS_LOG(INFO) << "verifyCAChain success.";
EVP_PKEY_free(caCertPubKey);
return true;
}
bool CommUtil::verifyExtendedAttributes(const X509 *caCert) {
MS_EXCEPTION_IF_NULL(caCert);
int cirt = 0;
BASIC_CONSTRAINTS *bcons =
reinterpret_cast<BASIC_CONSTRAINTS *>(X509_get_ext_d2i(caCert, NID_basic_constraints, &cirt, NULL));
if (bcons == nullptr) {
return false;
}
if (!bcons->ca) {
MS_LOG(ERROR) << "Subject Type is End Entity.";
return false;
}
MS_LOG(INFO) << "Subject Type is CA.";
return true;
}
void CommUtil::verifyCertPipeline(const X509 *caCert, const X509 *subCert) {
if (!CommUtil::VerifyCommonName(caCert, subCert)) {
MS_LOG(EXCEPTION) << "Verify common name failed.";
}
if (!CommUtil::verifySingature(caCert, subCert)) {
MS_LOG(EXCEPTION) << "Verify Singature failed.";
}
if (!CommUtil::verifyExtendedAttributes(caCert)) {
MS_LOG(EXCEPTION) << "Verify Extended Attributes failed.";
}
if (!CommUtil::verifyCertKeyID(caCert, subCert)) {
MS_LOG(EXCEPTION) << "Verify Cert KeyID failed.";
}
if (!CommUtil::verifyCertTimeStamp(caCert) || !CommUtil::verifyCertTimeStamp(subCert)) {
MS_LOG(EXCEPTION) << "Verify Cert Time failed.";
}
}
bool CommUtil::checkCRLTime(const std::string &crlPath) {
if (!IsFileExists(crlPath)) {
return true;
}
BIO *bio = BIO_new_file(crlPath.c_str(), "r");
if (bio == nullptr) {
return true;
}
bool result = true;
X509_CRL *crl = nullptr;
do {
crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
if (crl == nullptr) {
MS_LOG(WARNING) << "crl is nullptr. return true.";
result = true;
break;
}
const ASN1_TIME *lastUpdate = X509_CRL_get0_lastUpdate(crl);
const ASN1_TIME *nextUpdate = X509_CRL_get0_nextUpdate(crl);
int day = 0;
int sec = 0;
int ret = ASN1_TIME_diff(&day, &sec, lastUpdate, NULL);
if (ret != 1) {
result = false;
break;
}
if (day < 0 || sec < 0) {
MS_LOG(ERROR) << "crl start time is later than now time.";
result = false;
break;
}
day = 0;
sec = 0;
ret = ASN1_TIME_diff(&day, &sec, NULL, nextUpdate);
if (ret != 1) {
result = false;
break;
}
if (day < 0 || sec < 0) {
MS_LOG(WARNING) << "crl update time is sooner than now time. please update crl";
}
MS_LOG(INFO) << "verifyCRL time success.";
} while (0);
X509_CRL_free(crl);
BIO_free_all(bio);
return result;
}
std::string CommUtil::BoolToString(bool alive) {
if (alive) {
return "True";
} else {
return "False";
}
}
bool CommUtil::StringToBool(const std::string &alive) {
if (alive == "True") {
return true;
} else if (alive == "False") {
return false;
}
return false;
} }
} // namespace core } // namespace core
} // namespace ps } // namespace ps

View File

@ -44,6 +44,7 @@
#include <assert.h> #include <assert.h>
#include <openssl/pkcs12.h> #include <openssl/pkcs12.h>
#include <openssl/bio.h> #include <openssl/bio.h>
#include <openssl/x509v3.h>
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
@ -99,8 +100,12 @@ class CommUtil {
static bool CheckIp(const std::string &ip); static bool CheckIp(const std::string &ip);
static bool CheckPort(const uint16_t &port); static bool CheckPort(const uint16_t &port);
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
static std::string GetLoopBackInterfaceName();
static std::string GenerateUUID(); static std::string GenerateUUID();
static std::string NodeRoleToString(const NodeRole &role); static std::string NodeRoleToString(const NodeRole &role);
static NodeRole StringToNodeRole(const std::string &roleStr);
static std::string BoolToString(bool alive);
static bool StringToBool(const std::string &alive);
static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num, static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
const int32_t &total_server_num); const int32_t &total_server_num);
static bool Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds); static bool Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds);
@ -112,19 +117,21 @@ class CommUtil {
static std::string ClusterStateToString(const ClusterState &state); static std::string ClusterStateToString(const ClusterState &state);
// Parse the configuration file according to the key. // Parse the configuration file according to the key.
static std::string ParseConfig(const Configuration &config, const std::string &data); static std::string ParseConfig(const Configuration &config, const std::string &key);
// verify valid of certificate time // verify valid of certificate time
static bool VerifyCertTime(const X509 *cert, int64_t time = 0); static bool VerifyCertTime(const X509 *cert, int64_t time = 0);
static bool verifyCertTimeStamp(const X509 *cert);
// verify valid of equip certificate with CRL // verify valid of equip certificate with CRL
static bool VerifyCRL(const X509 *cert, const std::string &crl_path); static bool VerifyCRL(const X509 *cert, const std::string &crl_path);
// Check the common name of the certificate static bool VerifyCommonName(const X509 *caCert, const X509 *subCert);
static bool VerifyCommonName(const X509 *cert, const std::string &ca_path);
// The string is divided according to delim
static std::vector<std::string> Split(const std::string &s, char delim); static std::vector<std::string> Split(const std::string &s, char delim);
// Check the cipher list of the certificate
static bool VerifyCipherList(const std::vector<std::string> &list); static bool VerifyCipherList(const std::vector<std::string> &list);
static void InitOpenSSLEnv(); static bool verifyCertKeyID(const X509 *caCert, const X509 *subCert);
static bool verifySingature(const X509 *caCert, const X509 *subCert);
static bool verifyExtendedAttributes(const X509 *caCert);
static void verifyCertPipeline(const X509 *caCert, const X509 *subCert);
static bool checkCRLTime(const std::string &crlPath);
private: private:
static std::random_device rd; static std::random_device rd;

View File

@ -43,7 +43,7 @@ void CommunicatorBase::Join() {
return; return;
} }
bool CommunicatorBase::running() const { return running_; } bool CommunicatorBase::running() { return running_; }
} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

View File

@ -19,10 +19,10 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <unordered_map>
#include <functional> #include <functional>
#include <thread> #include <thread>
#include "utils/hash_map.h"
#include "ps/core/communicator/message_handler.h" #include "ps/core/communicator/message_handler.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "ps/core/communicator/http_message_handler.h" #include "ps/core/communicator/http_message_handler.h"
@ -57,6 +57,7 @@ enum class TcpUserCommand {
kQueryInstance, kQueryInstance,
kEnableFLS, kEnableFLS,
kDisableFLS, kDisableFLS,
kSyncAfterRecover,
kExchangeKeys, kExchangeKeys,
kGetKeys kGetKeys
}; };
@ -84,10 +85,10 @@ class CommunicatorBase {
bool SendResponse(const void *rsp_data, size_t rsp_len, const std::shared_ptr<MessageHandler> &msg_handler); bool SendResponse(const void *rsp_data, size_t rsp_len, const std::shared_ptr<MessageHandler> &msg_handler);
bool running() const; bool running();
protected: protected:
mindspore::HashMap<std::string, MessageCallback> msg_callbacks_; std::unordered_map<std::string, MessageCallback> msg_callbacks_;
std::thread running_thread_; std::thread running_thread_;
bool running_; bool running_;
}; };

View File

@ -22,11 +22,11 @@ namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
bool HttpCommunicator::Start() { bool HttpCommunicator::Start() {
MS_EXCEPTION_IF_NULL(http_server_);
MS_LOG(INFO) << "Initialize http server IP:" << ip_ << ", PORT:" << port_; MS_LOG(INFO) << "Initialize http server IP:" << ip_ << ", PORT:" << port_;
if (!http_server_->InitServer()) { if (!http_server_->InitServer()) {
MS_LOG(EXCEPTION) << "The communicator init http server failed."; MS_LOG(EXCEPTION) << "The communicator init http server failed.";
} }
MS_EXCEPTION_IF_NULL(http_server_);
if (!http_server_->Start()) { if (!http_server_->Start()) {
MS_LOG(EXCEPTION) << "Http server starting failed."; MS_LOG(EXCEPTION) << "Http server starting failed.";
} }
@ -51,9 +51,11 @@ bool HttpCommunicator::Stop() {
} }
void HttpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) { void HttpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) {
MS_LOG(INFO) << "msg_type is: " << msg_type;
msg_callbacks_[msg_type] = cb; msg_callbacks_[msg_type] = cb;
http_msg_callbacks_[msg_type] = std::bind( http_msg_callbacks_[msg_type] = std::bind(
[&](std::shared_ptr<HttpMessageHandler> http_msg) -> void { [&](std::shared_ptr<HttpMessageHandler> http_msg) -> void {
MS_EXCEPTION_IF_NULL(http_msg);
std::shared_ptr<MessageHandler> http_msg_handler = std::make_shared<HttpMsgHandler>(http_msg); std::shared_ptr<MessageHandler> http_msg_handler = std::make_shared<HttpMsgHandler>(http_msg);
MS_EXCEPTION_IF_NULL(http_msg_handler); MS_EXCEPTION_IF_NULL(http_msg_handler);
msg_callbacks_[msg_type](http_msg_handler); msg_callbacks_[msg_type](http_msg_handler);
@ -61,7 +63,8 @@ void HttpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const Me
}, },
std::placeholders::_1); std::placeholders::_1);
std::string url = "/"; std::string url = ps::PSContext::instance()->http_url_prefix();
url += "/";
url += msg_type; url += msg_type;
MS_EXCEPTION_IF_NULL(http_server_); MS_EXCEPTION_IF_NULL(http_server_);
bool is_succeed = http_server_->RegisterRoute(url, &http_msg_callbacks_[msg_type]); bool is_succeed = http_server_->RegisterRoute(url, &http_msg_callbacks_[msg_type]);

View File

@ -19,7 +19,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "utils/hash_map.h" #include <unordered_map>
#include "ps/core/communicator/http_server.h" #include "ps/core/communicator/http_server.h"
#include "ps/core/communicator/http_message_handler.h" #include "ps/core/communicator/http_message_handler.h"
#include "ps/core/communicator/task_executor.h" #include "ps/core/communicator/task_executor.h"
@ -46,7 +46,7 @@ class HttpCommunicator : public CommunicatorBase {
private: private:
std::shared_ptr<TaskExecutor> task_executor_; std::shared_ptr<TaskExecutor> task_executor_;
std::shared_ptr<HttpServer> http_server_; std::shared_ptr<HttpServer> http_server_;
mindspore::HashMap<std::string, HttpMsgCallback> http_msg_callbacks_; std::unordered_map<std::string, HttpMsgCallback> http_msg_callbacks_;
std::string ip_; std::string ip_;
uint16_t port_; uint16_t port_;

View File

@ -265,7 +265,7 @@ void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, co
} }
void HttpMessageHandler::ErrorResponse(int code, const RequestProcessResult &result) { void HttpMessageHandler::ErrorResponse(int code, const RequestProcessResult &result) {
nlohmann::json error_json = {{"error_message", result.StatusMessage()}}; nlohmann::json error_json = {{"error_message", result.StatusMessage()}, {"code", kErrorCode}};
std::string out_error = error_json.dump(); std::string out_error = error_json.dump();
AddRespString(out_error); AddRespString(out_error);
SetRespCode(code); SetRespCode(code);

View File

@ -26,7 +26,7 @@ HttpRequestHandler::~HttpRequestHandler() {
} }
} }
bool HttpRequestHandler::Initialize(int fd, const mindspore::HashMap<std::string, OnRequestReceive *> &handlers) { bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers) {
evbase_ = event_base_new(); evbase_ = event_base_new();
MS_EXCEPTION_IF_NULL(evbase_); MS_EXCEPTION_IF_NULL(evbase_);
struct evhttp *http = evhttp_new(evbase_); struct evhttp *http = evhttp_new(evbase_);
@ -115,8 +115,7 @@ bufferevent *HttpRequestHandler::BuffereventCallback(event_base *base, void *arg
SSL_CTX *ctx = reinterpret_cast<SSL_CTX *>(arg); SSL_CTX *ctx = reinterpret_cast<SSL_CTX *>(arg);
SSL *ssl = SSL_new(ctx); SSL *ssl = SSL_new(ctx);
MS_EXCEPTION_IF_NULL(ssl); MS_EXCEPTION_IF_NULL(ssl);
bufferevent *bev = bufferevent *bev = bufferevent_openssl_socket_new(base, -1, ssl, BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE);
bufferevent_openssl_socket_new(base, -1, ssl, BUFFEREVENT_SSL_ACCEPTING, static_cast<int>(BEV_OPT_CLOSE_ON_FREE));
MS_EXCEPTION_IF_NULL(bev); MS_EXCEPTION_IF_NULL(bev);
return bev; return bev;
} }

View File

@ -25,8 +25,8 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <unordered_map>
#include "utils/hash_map.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "ps/core/communicator/http_message_handler.h" #include "ps/core/communicator/http_message_handler.h"
#include "ps/core/communicator/ssl_http.h" #include "ps/core/communicator/ssl_http.h"
@ -46,7 +46,7 @@ class HttpRequestHandler {
HttpRequestHandler() : evbase_(nullptr) {} HttpRequestHandler() : evbase_(nullptr) {}
virtual ~HttpRequestHandler(); virtual ~HttpRequestHandler();
bool Initialize(int fd, const mindspore::HashMap<std::string, OnRequestReceive *> &handlers); bool Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers);
void Run(); void Run();
bool Stop(); bool Stop();
static bufferevent *BuffereventCallback(event_base *base, void *arg); static bufferevent *BuffereventCallback(event_base *base, void *arg);

View File

@ -68,7 +68,7 @@ bool HttpServer::InitServer() {
return false; return false;
} }
fd_ = ::socket(static_cast<int>(AF_INET), static_cast<int>(SOCK_STREAM), 0); fd_ = ::socket(AF_INET, SOCK_STREAM, 0);
if (fd_ < 0) { if (fd_ < 0) {
MS_LOG(ERROR) << "Socker error!"; MS_LOG(ERROR) << "Socker error!";
return false; return false;
@ -84,7 +84,8 @@ bool HttpServer::InitServer() {
} }
struct sockaddr_in addr; struct sockaddr_in addr;
if (memset_s(&addr, sizeof(addr), 0, sizeof(addr)) != EOK) { errno_t ret = memset_s(&addr, sizeof(addr), 0, sizeof(addr));
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Memset failed."; MS_LOG(EXCEPTION) << "Memset failed.";
} }
@ -132,6 +133,7 @@ bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *functio
if (!function) { if (!function) {
return false; return false;
} }
MS_LOG(INFO) << "request handler url is: " << url;
request_handlers_[url] = function; request_handlers_[url] = function;
return true; return true;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020-2021 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -17,6 +17,8 @@
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_ #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_ #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_
#include "ps/core/communicator/http_message_handler.h"
#include <event2/buffer.h> #include <event2/buffer.h>
#include <event2/event.h> #include <event2/event.h>
#include <event2/http.h> #include <event2/http.h>
@ -34,10 +36,9 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <atomic> #include <atomic>
#include <unordered_map>
#include <vector> #include <vector>
#include "utils/hash_map.h"
#include "ps/core/communicator/http_message_handler.h"
#include "ps/core/communicator/http_request_handler.h" #include "ps/core/communicator/http_request_handler.h"
namespace mindspore { namespace mindspore {
@ -77,7 +78,7 @@ class HttpServer {
std::vector<std::shared_ptr<std::thread>> worker_threads_; std::vector<std::shared_ptr<std::thread>> worker_threads_;
std::vector<std::shared_ptr<HttpRequestHandler>> http_request_handlers; std::vector<std::shared_ptr<HttpRequestHandler>> http_request_handlers;
int32_t backlog_; int32_t backlog_;
mindspore::HashMap<std::string, OnRequestReceive *> request_handlers_; std::unordered_map<std::string, OnRequestReceive *> request_handlers_;
int fd_; int fd_;
}; };
} // namespace core } // namespace core

View File

@ -1,3 +1,4 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
@ -37,7 +38,18 @@ SSLClient::SSLClient() : ssl_ctx_(nullptr), check_time_thread_(nullptr), running
SSLClient::~SSLClient() { CleanSSL(); } SSLClient::~SSLClient() { CleanSSL(); }
void SSLClient::InitSSL() { void SSLClient::InitSSL() {
CommUtil::InitOpenSSLEnv(); if (!SSL_library_init()) {
MS_LOG(EXCEPTION) << "SSL_library_init failed.";
}
if (!ERR_load_crypto_strings()) {
MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
}
if (!SSL_load_error_strings()) {
MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
}
if (!OpenSSL_add_all_algorithms()) {
MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
}
ssl_ctx_ = SSL_CTX_new(SSLv23_client_method()); ssl_ctx_ = SSL_CTX_new(SSLv23_client_method());
if (!ssl_ctx_) { if (!ssl_ctx_) {
MS_LOG(EXCEPTION) << "SSL_CTX_new failed"; MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
@ -65,6 +77,7 @@ void SSLClient::InitSSL() {
EVP_PKEY *pkey = nullptr; EVP_PKEY *pkey = nullptr;
X509 *cert = nullptr; X509 *cert = nullptr;
STACK_OF(X509) *ca_stack = nullptr; STACK_OF(X509) *ca_stack = nullptr;
MS_LOG(INFO) << "cliet cert: " << client_cert;
BIO *bio = BIO_new_file(client_cert.c_str(), "rb"); BIO *bio = BIO_new_file(client_cert.c_str(), "rb");
MS_EXCEPTION_IF_NULL(bio); MS_EXCEPTION_IF_NULL(bio);
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr); PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
@ -83,27 +96,26 @@ void SSLClient::InitSSL() {
} }
// 3. load ca cert. // 3. load ca cert.
std::string client_ca = kCAcrt;
std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath); std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
if (!CommUtil::IsFileExists(ca_path)) { if (!CommUtil::IsFileExists(ca_path)) {
MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist."; MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
} }
client_ca = ca_path; BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
MS_EXCEPTION_IF_NULL(ca_bio);
X509 *caCert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath); std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
if (crl_path.empty()) { if (crl_path.empty()) {
MS_LOG(INFO) << "The crl path is empty."; MS_LOG(INFO) << "The crl path is empty.";
} else if (!CommUtil::checkCRLTime(crl_path)) {
MS_LOG(EXCEPTION) << "check crl time failed";
} else if (!CommUtil::VerifyCRL(cert, crl_path)) { } else if (!CommUtil::VerifyCRL(cert, crl_path)) {
MS_LOG(EXCEPTION) << "Verify crl failed."; MS_LOG(EXCEPTION) << "Verify crl failed.";
} }
if (!CommUtil::VerifyCommonName(cert, client_ca)) { CommUtil::verifyCertPipeline(caCert, cert);
MS_LOG(EXCEPTION) << "Verify common name failed.";
}
SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, 0); SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, 0);
if (!SSL_CTX_load_verify_locations(ssl_ctx_, ca_path.c_str(), nullptr)) {
if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) {
MS_LOG(EXCEPTION) << "SSL load ca location failed!"; MS_LOG(EXCEPTION) << "SSL load ca location failed!";
} }
@ -115,13 +127,21 @@ void SSLClient::InitSSL() {
if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) { if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
MS_LOG(EXCEPTION) << "SSL use set cipher list failed!"; MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
} }
InitSSLCtx(cert, pkey);
StartCheckCertTime(*config_, cert);
// 4. load client cert EVP_PKEY_free(pkey);
if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) { X509_free(caCert);
X509_free(cert);
(void)BIO_free(ca_bio);
}
void SSLClient::InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey) {
if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast<X509 *>(cert))) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
} }
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) { if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast<EVP_PKEY *>(pkey))) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!"; MS_LOG(EXCEPTION) << "SSL use private key file failed!";
} }
@ -138,7 +158,7 @@ void SSLClient::InitSSL() {
MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!"; MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
} }
StartCheckCertTime(*config_, cert); SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel);
} }
void SSLClient::CleanSSL() { void SSLClient::CleanSSL() {

View File

@ -37,6 +37,7 @@
#include "ps/core/comm_util.h" #include "ps/core/comm_util.h"
#include "ps/constants.h" #include "ps/constants.h"
#include "ps/core/file_configuration.h" #include "ps/core/file_configuration.h"
#include "ps/ps_context.h"
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
@ -60,6 +61,7 @@ class SSLClient {
void StartCheckCertTime(const Configuration &config, const X509 *cert); void StartCheckCertTime(const Configuration &config, const X509 *cert);
void StopCheckCertTime(); void StopCheckCertTime();
void InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey);
SSL_CTX *ssl_ctx_; SSL_CTX *ssl_ctx_;
std::unique_ptr<std::thread> check_time_thread_; std::unique_ptr<std::thread> check_time_thread_;

View File

@ -1,3 +1,4 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
@ -35,7 +36,18 @@ SSLHTTP::SSLHTTP() : ssl_ctx_(nullptr) { InitSSL(); }
SSLHTTP::~SSLHTTP() { CleanSSL(); } SSLHTTP::~SSLHTTP() { CleanSSL(); }
void SSLHTTP::InitSSL() { void SSLHTTP::InitSSL() {
CommUtil::InitOpenSSLEnv(); if (!SSL_library_init()) {
MS_LOG(EXCEPTION) << "SSL_library_init failed.";
}
if (!ERR_load_crypto_strings()) {
MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
}
if (!SSL_load_error_strings()) {
MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
}
if (!OpenSSL_add_all_algorithms()) {
MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
}
ssl_ctx_ = SSL_CTX_new(SSLv23_server_method()); ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
if (!ssl_ctx_) { if (!ssl_ctx_) {
MS_LOG(EXCEPTION) << "SSL_CTX_new failed"; MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
@ -79,25 +91,39 @@ void SSLHTTP::InitSSL() {
MS_LOG(EXCEPTION) << "PKCS12_parse failed."; MS_LOG(EXCEPTION) << "PKCS12_parse failed.";
} }
PKCS12_free(p12); PKCS12_free(p12);
if (cert == nullptr) {
MS_LOG(EXCEPTION) << "the cert is nullptr";
}
if (pkey == nullptr) {
MS_LOG(EXCEPTION) << "the key is nullptr";
}
if (!CommUtil::verifyCertTimeStamp(cert)) {
MS_LOG(EXCEPTION) << "Verify Cert Time failed.";
}
std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList); std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList);
InitSSLCtx(cert, pkey, default_cipher_list);
EVP_PKEY_free(pkey);
X509_free(cert);
}
void SSLHTTP::InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey, const std::string &default_cipher_list) {
if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) { if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
MS_LOG(EXCEPTION) << "SSL use set cipher list failed!"; MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
} }
if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast<X509 *>(cert))) {
if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
} }
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) { if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast<EVP_PKEY *>(pkey))) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!"; MS_LOG(EXCEPTION) << "SSL use private key file failed!";
} }
if (!SSL_CTX_check_private_key(ssl_ctx_)) { if (!SSL_CTX_check_private_key(ssl_ctx_)) {
MS_LOG(EXCEPTION) << "SSL check private key file failed!"; MS_LOG(EXCEPTION) << "SSL check private key file failed!";
} }
if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) { SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed."; MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
} }
SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel);
} }
void SSLHTTP::CleanSSL() { void SSLHTTP::CleanSSL() {

View File

@ -53,7 +53,7 @@ class SSLHTTP {
void InitSSL(); void InitSSL();
void CleanSSL(); void CleanSSL();
void InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey, const std::string &default_cipher_list);
SSL_CTX *ssl_ctx_; SSL_CTX *ssl_ctx_;
}; };
} // namespace core } // namespace core

View File

@ -1,3 +1,4 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
@ -43,7 +44,18 @@ SSLWrapper::SSLWrapper()
SSLWrapper::~SSLWrapper() { CleanSSL(); } SSLWrapper::~SSLWrapper() { CleanSSL(); }
void SSLWrapper::InitSSL() { void SSLWrapper::InitSSL() {
CommUtil::InitOpenSSLEnv(); if (!SSL_library_init()) {
MS_LOG(EXCEPTION) << "SSL_library_init failed.";
}
if (!ERR_load_crypto_strings()) {
MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
}
if (!SSL_load_error_strings()) {
MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
}
if (!OpenSSL_add_all_algorithms()) {
MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
}
ssl_ctx_ = SSL_CTX_new(SSLv23_server_method()); ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
if (!ssl_ctx_) { if (!ssl_ctx_) {
MS_LOG(EXCEPTION) << "SSL_CTX_new failed"; MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
@ -100,31 +112,42 @@ void SSLWrapper::InitSSL() {
std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath); std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
if (crl_path.empty()) { if (crl_path.empty()) {
MS_LOG(INFO) << "The crl path is empty."; MS_LOG(INFO) << "The crl path is empty.";
} else if (!CommUtil::checkCRLTime(crl_path)) {
MS_LOG(EXCEPTION) << "check crl time failed";
} else if (!CommUtil::VerifyCRL(cert, crl_path)) { } else if (!CommUtil::VerifyCRL(cert, crl_path)) {
MS_LOG(EXCEPTION) << "Verify crl failed."; MS_LOG(EXCEPTION) << "Verify crl failed.";
} }
std::string client_ca = kCAcrt;
std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath); std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
if (!CommUtil::IsFileExists(ca_path)) { if (!CommUtil::IsFileExists(ca_path)) {
MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist."; MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
} }
client_ca = ca_path; BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
MS_EXCEPTION_IF_NULL(ca_bio);
X509 *caCert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
if (!CommUtil::VerifyCommonName(cert, client_ca)) { CommUtil::verifyCertPipeline(caCert, cert);
MS_LOG(EXCEPTION) << "Verify common name failed.";
}
SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER, 0); SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER, 0);
if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) { if (!SSL_CTX_load_verify_locations(ssl_ctx_, ca_path.c_str(), nullptr)) {
MS_LOG(EXCEPTION) << "SSL load ca location failed!"; MS_LOG(EXCEPTION) << "SSL load ca location failed!";
} }
if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) { InitSSLCtx(cert, pkey);
StartCheckCertTime(*config_, cert, ca_path);
EVP_PKEY_free(pkey);
X509_free(caCert);
X509_free(cert);
(void)BIO_free(ca_bio);
}
void SSLWrapper::InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey) {
if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast<X509 *>(cert))) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
} }
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) { if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast<EVP_PKEY *>(pkey))) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!"; MS_LOG(EXCEPTION) << "SSL use private key file failed!";
} }
@ -135,12 +158,10 @@ void SSLWrapper::InitSSL() {
SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) { SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed."; MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
} }
if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) { if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) {
MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!"; MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
} }
SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel);
StartCheckCertTime(*config_, cert, client_ca);
} }
void SSLWrapper::CleanSSL() { void SSLWrapper::CleanSSL() {

View File

@ -60,6 +60,7 @@ class SSLWrapper {
time_t ConvertAsn1Time(const ASN1_TIME *const time) const; time_t ConvertAsn1Time(const ASN1_TIME *const time) const;
void StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path); void StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path);
void StopCheckCertTime(); void StopCheckCertTime();
void InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey);
SSL_CTX *ssl_ctx_; SSL_CTX *ssl_ctx_;

View File

@ -94,7 +94,6 @@ void TcpClient::Init() {
if (event_base_ == nullptr) { if (event_base_ == nullptr) {
event_base_ = event_base_new(); event_base_ = event_base_new();
MS_EXCEPTION_IF_NULL(event_base_); MS_EXCEPTION_IF_NULL(event_base_);
is_stop_ = false;
} }
sockaddr_in sin{}; sockaddr_in sin{};
@ -160,7 +159,7 @@ void TcpClient::Stop() {
void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) { void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
const int one = 1; const int one = 1;
int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int)); int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
if (ret < 0) { if (ret < 0) {
MS_LOG(EXCEPTION) << "Set socket no delay failed!"; MS_LOG(EXCEPTION) << "Set socket no delay failed!";
} }
@ -178,10 +177,10 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
auto tcp_client = reinterpret_cast<TcpClient *>(ctx); auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
char read_buffer[kMessageChunkLength]; char read_buffer[kMessageChunkLength];
size_t read = 0; int read = 0;
while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) { while ((read = bufferevent_read(bev, &read_buffer, SizeToInt(sizeof(read_buffer)))) > 0) {
tcp_client->OnReadHandler(read_buffer, read); tcp_client->OnReadHandler(read_buffer, IntToSize(read));
} }
} }
@ -252,6 +251,8 @@ void TcpClient::Start() {
event_base_mutex_.unlock(); event_base_mutex_.unlock();
MS_EXCEPTION_IF_NULL(event_base_); MS_EXCEPTION_IF_NULL(event_base_);
int ret = event_base_dispatch(event_base_); int ret = event_base_dispatch(event_base_);
// is_started_ should be false when finish dispatch
is_started_ = false;
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
<< "Event base dispatch failed with no events pending or active!"; << "Event base dispatch failed with no events pending or active!";

View File

@ -90,6 +90,7 @@ bool TcpCommunicator::Stop() {
} }
void TcpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) { void TcpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) {
MS_LOG(INFO) << "msg_type is: " << msg_type;
msg_callbacks_.try_emplace(msg_type, cb); msg_callbacks_.try_emplace(msg_type, cb);
return; return;
} }

View File

@ -21,7 +21,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <memory> #include <memory>
#include "utils/hash_map.h" #include <unordered_map>
#include "proto/ps.pb.h" #include "proto/ps.pb.h"
#include "ps/core/server_node.h" #include "ps/core/server_node.h"
#include "ps/core/cluster_metadata.h" #include "ps/core/cluster_metadata.h"
@ -36,7 +36,7 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
const mindspore::HashMap<TcpUserCommand, std::string> kUserCommandToMsgType = { const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
{TcpUserCommand::kPush, "push"}, {TcpUserCommand::kPush, "push"},
{TcpUserCommand::kPull, "pull"}, {TcpUserCommand::kPull, "pull"},
{TcpUserCommand::kCount, "count"}, {TcpUserCommand::kCount, "count"},
@ -61,7 +61,8 @@ const mindspore::HashMap<TcpUserCommand, std::string> kUserCommandToMsgType = {
{TcpUserCommand::kNewInstance, "newInstance"}, {TcpUserCommand::kNewInstance, "newInstance"},
{TcpUserCommand::kQueryInstance, "queryInstance"}, {TcpUserCommand::kQueryInstance, "queryInstance"},
{TcpUserCommand::kEnableFLS, "enableFLS"}, {TcpUserCommand::kEnableFLS, "enableFLS"},
{TcpUserCommand::kDisableFLS, "disableFLS"}}; {TcpUserCommand::kDisableFLS, "disableFLS"},
{TcpUserCommand::kSyncAfterRecover, "syncAfterRecover"}};
class TcpCommunicator : public CommunicatorBase { class TcpCommunicator : public CommunicatorBase {
public: public:
@ -92,8 +93,9 @@ class TcpCommunicator : public CommunicatorBase {
MS_ERROR_IF_NULL_W_RET_VAL(msg, false); MS_ERROR_IF_NULL_W_RET_VAL(msg, false);
size_t dest_size = msg_str.size(); size_t dest_size = msg_str.size();
size_t src_size = msg_str.size(); size_t src_size = msg_str.size();
if (memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size) != EOK) { auto ret = memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size);
MS_LOG(EXCEPTION) << "Memcpy_s error"; if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy_s error, error no " << ret;
} }
if (output != nullptr) { if (output != nullptr) {

View File

@ -20,8 +20,8 @@
namespace mindspore { namespace mindspore {
namespace ps { namespace ps {
namespace core { namespace core {
TcpMsgHandler::TcpMsgHandler(AbstractNode *const abstract_node, const std::shared_ptr<core::TcpConnection> &conn, TcpMsgHandler::TcpMsgHandler(AbstractNode *abstract_node, const std::shared_ptr<core::TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const DataPtr &data, size_t size) const std::shared_ptr<MessageMeta> &meta, DataPtr data, size_t size)
: abstract_node_(abstract_node), tcp_conn_(conn), meta_(meta), data_ptr_(data), data_(nullptr), len_(size) { : abstract_node_(abstract_node), tcp_conn_(conn), meta_(meta), data_ptr_(data), data_(nullptr), len_(size) {
if (data_ptr_ != nullptr) { if (data_ptr_ != nullptr) {
data_ = data_ptr_.get(); data_ = data_ptr_.get();

View File

@ -28,8 +28,8 @@ namespace ps {
namespace core { namespace core {
class TcpMsgHandler : public MessageHandler { class TcpMsgHandler : public MessageHandler {
public: public:
TcpMsgHandler(AbstractNode *const abstract_node, const std::shared_ptr<core::TcpConnection> &conn, TcpMsgHandler(AbstractNode *abstract_node, const std::shared_ptr<core::TcpConnection> &conn,
const std::shared_ptr<MessageMeta> &meta, const DataPtr &data, size_t size); const std::shared_ptr<MessageMeta> &meta, DataPtr data, size_t size);
~TcpMsgHandler() override = default; ~TcpMsgHandler() override = default;
void *data() const override; void *data() const override;

View File

@ -175,8 +175,9 @@ void TcpServer::Init() {
listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this), listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1, LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin)); reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
if (listener_ == nullptr) {
MS_EXCEPTION_IF_NULL(listener_); MS_LOG(EXCEPTION) << "bind ip & port failed. please check.";
}
if (server_port_ == 0) { if (server_port_ == 0) {
struct sockaddr_in sin_bound {}; struct sockaddr_in sin_bound {};
@ -306,6 +307,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
}); });
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
reinterpret_cast<void *>(conn.get())); reinterpret_cast<void *>(conn.get()));
MS_LOG(INFO) << "A client is connected, fd is " << fd;
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!"; MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
} }
@ -411,7 +413,7 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) { void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) {
const int one = 1; const int one = 1;
int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int)); int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
if (ret < 0) { if (ret < 0) {
MS_LOG(EXCEPTION) << "Set socket no delay failed!"; MS_LOG(EXCEPTION) << "Set socket no delay failed!";
} }

View File

@ -25,9 +25,10 @@
#include <vector> #include <vector>
#include <thread> #include <thread>
#include <mutex> #include <mutex>
#include <unordered_map>
#include "utils/hash_map.h"
#include "ps/constants.h" #include "ps/constants.h"
#include "nlohmann/json.hpp"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
@ -51,6 +52,9 @@ class Configuration {
// Get configuration data from database or config file. // Get configuration data from database or config file.
virtual std::string GetString(const std::string &key, const std::string &defaultvalue) const = 0; virtual std::string GetString(const std::string &key, const std::string &defaultvalue) const = 0;
// Get configuration vector data from database or config file.
virtual std::vector<nlohmann::json> GetVector(const std::string &key) const = 0;
// Get configuration data from database or config file. // Get configuration data from database or config file.
virtual int64_t GetInt(const std::string &key, int64_t default_value) const = 0; virtual int64_t GetInt(const std::string &key, int64_t default_value) const = 0;
@ -59,6 +63,12 @@ class Configuration {
// Determine whether the configuration item exists. // Determine whether the configuration item exists.
virtual bool Exists(const std::string &key) const = 0; virtual bool Exists(const std::string &key) const = 0;
// storage meta data
virtual void PersistFile(const core::ClusterConfig &clusterConfig) const = 0;
// storage meta data without nodes
virtual void PersistNodes(const core::ClusterConfig &clusterConfig) const = 0;
}; };
} // namespace core } // namespace core
} // namespace ps } // namespace ps

View File

@ -50,6 +50,15 @@ std::string FileConfiguration::Get(const std::string &key, const std::string &de
return res; return res;
} }
std::vector<nlohmann::json> FileConfiguration::GetVector(const std::string &key) const {
if (!js.contains(key)) {
MS_LOG(WARNING) << "The key:" << key << " is not exist.";
return std::vector<nlohmann::json>();
}
return js.at(key);
}
std::string FileConfiguration::GetString(const std::string &key, const std::string &defaultvalue) const { std::string FileConfiguration::GetString(const std::string &key, const std::string &defaultvalue) const {
if (!js.contains(key)) { if (!js.contains(key)) {
MS_LOG(WARNING) << "The key:" << key << " is not exist."; MS_LOG(WARNING) << "The key:" << key << " is not exist.";
@ -68,13 +77,7 @@ int64_t FileConfiguration::GetInt(const std::string &key, int64_t default_value)
return res; return res;
} }
void FileConfiguration::Put(const std::string &key, const std::string &value) { void FileConfiguration::Put(const std::string &key, const std::string &value) { js[key] = value; }
std::ofstream output_file(file_path_);
js[key] = value;
output_file << js.dump();
output_file.close();
}
bool FileConfiguration::Exists(const std::string &key) const { bool FileConfiguration::Exists(const std::string &key) const {
if (!js.contains(key)) { if (!js.contains(key)) {
@ -82,6 +85,53 @@ bool FileConfiguration::Exists(const std::string &key) const {
} }
return true; return true;
} }
void FileConfiguration::PersistNodes(const core::ClusterConfig &clusterConfig) const {
if (!CommUtil::IsFileExists(file_path_)) {
MS_LOG(WARNING) << "The file path:" << file_path_ << " is not exist. create one";
}
nlohmann::json persist_js;
persist_js[kRecoveryTotalNodeNum] = clusterConfig.initial_total_node_num;
persist_js[kRecoveryNextWorkerRankId] = clusterConfig.initial_next_worker_rank_id;
persist_js[kRecoveryNextServerRankId] = clusterConfig.initial_next_server_rank_id;
auto node_infos = clusterConfig.initial_registered_nodes_infos;
for (const auto kvs : node_infos) {
std::unordered_map<std::string, std::string> res;
res["ip"] = kvs.second.ip_;
res["port"] = std::to_string(kvs.second.port_);
res["node_id"] = kvs.second.node_id_;
res["rank_id"] = std::to_string(kvs.second.rank_id_);
res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_);
res["alive"] = CommUtil::BoolToString(kvs.second.is_alive);
persist_js["node_ids"].push_back(res);
}
std::ofstream output_file(file_path_);
output_file << persist_js.dump();
output_file.close();
MS_LOG(INFO) << "The nodes meta data persist to " << file_path_;
}
void FileConfiguration::PersistFile(const core::ClusterConfig &clusterConfig) const {
if (!CommUtil::IsFileExists(file_path_)) {
MS_LOG(WARNING) << "The file path:" << file_path_ << " is not exist. create one";
}
nlohmann::json persist_js;
persist_js[kRecoveryWorkerNum] = clusterConfig.initial_worker_num;
persist_js[kRecoveryServerNum] = clusterConfig.initial_server_num;
persist_js[kRecoverySchedulerIp] = clusterConfig.scheduler_host;
persist_js[kRecoverySchedulerPort] = clusterConfig.scheduler_port;
std::ofstream output_file(file_path_);
output_file << persist_js.dump();
output_file.close();
MS_LOG(INFO) << "The meta data persist to " << file_path_;
}
} // namespace core } // namespace core
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More