forked from mindspore-Ecosystem/mindspore
!27108 code sync in federated learning
Merge pull request !27108 from tan-wei-cheng-3260/develop-twc-sync2
This commit is contained in:
commit
82542f1d4a
|
@ -30,7 +30,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// The duration between two PullWeights requests when return code is ResponseCode_SucNotReady.
|
// The duration between two PullWeight requests when return code is ResponseCode_SucNotReady.
|
||||||
constexpr int kRetryDurationOfPullWeights = 200;
|
constexpr int kRetryDurationOfPullWeights = 200;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class FusedPullWeightKernel : public CPUKernel {
|
class FusedPullWeightKernel : public CPUKernel {
|
||||||
|
@ -52,11 +52,11 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
total_iteration_++;
|
total_iteration_++;
|
||||||
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
|
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
|
||||||
if (step_num_per_iteration == 0) {
|
if (step_num_per_iteration == 0) {
|
||||||
MS_LOG(EXCEPTION) << "Step numbers of per iteration should not be equal to 0";
|
MS_LOG(EXCEPTION) << "step number per iteration should not be 0";
|
||||||
}
|
}
|
||||||
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
|
||||||
MS_LOG(INFO) << "Try to pull weights. Local step number: " << total_iteration_
|
MS_LOG(INFO) << "Try to pull weights. Local step number: " << total_iteration_
|
||||||
<< ", step number needs to run per iteration: " << step_num_per_iteration;
|
<< ", step number needs to run per iteration: " << step_num_per_iteration;
|
||||||
|
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
||||||
if (step_num_per_iteration != fl::kOneStepPerIteration &&
|
if (step_num_per_iteration != fl::kOneStepPerIteration &&
|
||||||
total_iteration_ % step_num_per_iteration != fl::kTrainBeginStepNum) {
|
total_iteration_ % step_num_per_iteration != fl::kTrainBeginStepNum) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -86,6 +86,7 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
|
MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg);
|
||||||
|
|
||||||
pull_weight_rsp = flatbuffers::GetRoot<schema::ResponsePullWeight>(pull_weight_rsp_msg->data());
|
pull_weight_rsp = flatbuffers::GetRoot<schema::ResponsePullWeight>(pull_weight_rsp_msg->data());
|
||||||
|
MS_EXCEPTION_IF_NULL(pull_weight_rsp);
|
||||||
retcode = pull_weight_rsp->retcode();
|
retcode = pull_weight_rsp->retcode();
|
||||||
if (retcode == schema::ResponseCode_SucNotReady) {
|
if (retcode == schema::ResponseCode_SucNotReady) {
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPullWeights));
|
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPullWeights));
|
||||||
|
@ -95,12 +96,12 @@ class FusedPullWeightKernel : public CPUKernel {
|
||||||
// Recreate fbb to avoid memory leak of FlatBuffers.
|
// Recreate fbb to avoid memory leak of FlatBuffers.
|
||||||
fbb = std::make_shared<fl::FBBuilder>();
|
fbb = std::make_shared<fl::FBBuilder>();
|
||||||
if (!BuildPullWeightReq(fbb)) {
|
if (!BuildPullWeightReq(fbb)) {
|
||||||
MS_LOG(EXCEPTION) << "Building request for FusedDownloadWeightsByKeys failed.";
|
MS_LOG(EXCEPTION) << "Building request for FusedPullWeight failed.";
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
} else if (retcode != schema::ResponseCode_SUCCEED) {
|
} else if (retcode != schema::ResponseCode_SUCCEED) {
|
||||||
MS_LOG(EXCEPTION) << "FusedPullWeight failed. Server return code: " << pull_weight_rsp->retcode()
|
MS_LOG(WARNING) << "FusedPullWeight failed. Server return code: " << pull_weight_rsp->retcode()
|
||||||
<< ", reason: " << pull_weight_rsp->reason()->str();
|
<< ", reason: " << pull_weight_rsp->reason()->str();
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "FusedPullWeight succeed.";
|
MS_LOG(DEBUG) << "FusedPullWeight succeed.";
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// The duration between two PushWeights requests when return code is ResponseCode_SucNotReady.
|
// The duration between two PushWeight requests when return code is ResponseCode_SucNotReady.
|
||||||
constexpr int kRetryDurationOfPushWeights = 200;
|
constexpr int kRetryDurationOfPushWeights = 200;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class FusedPushWeightKernel : public CPUKernel {
|
class FusedPushWeightKernel : public CPUKernel {
|
||||||
|
@ -50,11 +50,11 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
total_iteration_++;
|
total_iteration_++;
|
||||||
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
|
uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration();
|
||||||
if (step_num_per_iteration == 0) {
|
if (step_num_per_iteration == 0) {
|
||||||
MS_LOG(EXCEPTION) << "Step numbers of per iteration should not be equal to 0";
|
MS_LOG(EXCEPTION) << "step number per iterationb should not be 0";
|
||||||
}
|
}
|
||||||
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
|
||||||
MS_LOG(INFO) << "Try to push weights. Local step number: " << total_iteration_
|
MS_LOG(INFO) << "Try to push weights. Local step number: " << total_iteration_
|
||||||
<< ", step number needs to run per iteration: " << step_num_per_iteration;
|
<< ", step number needs to run per iteration: " << step_num_per_iteration;
|
||||||
|
// The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server.
|
||||||
if (step_num_per_iteration != fl::kOneStepPerIteration &&
|
if (step_num_per_iteration != fl::kOneStepPerIteration &&
|
||||||
total_iteration_ % step_num_per_iteration != fl::kTrainEndStepNum) {
|
total_iteration_ % step_num_per_iteration != fl::kTrainEndStepNum) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -87,6 +87,7 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
|
MS_EXCEPTION_IF_NULL(push_weight_rsp_msg);
|
||||||
|
|
||||||
push_weight_rsp = flatbuffers::GetRoot<schema::ResponsePushWeight>(push_weight_rsp_msg->data());
|
push_weight_rsp = flatbuffers::GetRoot<schema::ResponsePushWeight>(push_weight_rsp_msg->data());
|
||||||
|
MS_EXCEPTION_IF_NULL(push_weight_rsp);
|
||||||
retcode = push_weight_rsp->retcode();
|
retcode = push_weight_rsp->retcode();
|
||||||
if (retcode == schema::ResponseCode_SucNotReady) {
|
if (retcode == schema::ResponseCode_SucNotReady) {
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushWeights));
|
std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushWeights));
|
||||||
|
@ -98,8 +99,8 @@ class FusedPushWeightKernel : public CPUKernel {
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
} else if (retcode != schema::ResponseCode_SUCCEED) {
|
} else if (retcode != schema::ResponseCode_SUCCEED) {
|
||||||
MS_LOG(EXCEPTION) << "FusedPushWeight failed. Server return code: " << push_weight_rsp->retcode()
|
MS_LOG(WARNING) << "FusedPushWeight failed. Server return code: " << push_weight_rsp->retcode()
|
||||||
<< ", reason: " << push_weight_rsp->reason()->str();
|
<< ", reason: " << push_weight_rsp->reason()->str();
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "FusedPushWeight succeed.";
|
MS_LOG(DEBUG) << "FusedPushWeight succeed.";
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,7 +87,7 @@ class PushMetricsKernel : public CPUKernel {
|
||||||
case schema::ResponseCode_OutOfTime:
|
case schema::ResponseCode_OutOfTime:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(EXCEPTION) << "Launching push metrics for worker failed.";
|
MS_LOG(WARNING) << "Launching push metrics for worker failed.";
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Push metrics for loss and accuracy success.";
|
MS_LOG(INFO) << "Push metrics for loss and accuracy success.";
|
||||||
|
|
|
@ -5,7 +5,6 @@ if(NOT ENABLE_CPU OR WIN32)
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/aggregation_kernel_factory.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/fed_avg_kernel.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/fed_avg_kernel.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/sgd_kernel.cc")
|
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/optimizer_kernel_factory.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel_factory.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel.cc")
|
||||||
|
@ -20,6 +19,8 @@ if(NOT ENABLE_CPU OR WIN32)
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_secrets_kernel.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_secrets_kernel.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/reconstruct_secrets_kernel.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/reconstruct_secrets_kernel.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/share_secrets_kernel.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/share_secrets_kernel.cc")
|
||||||
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_list_sign_kernel.cc")
|
||||||
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_list_sign_kernel.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_metrics_kernel.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_metrics_kernel.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/params_info.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/params_info.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/consistent_hash_ring.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/consistent_hash_ring.cc")
|
||||||
|
@ -35,6 +36,8 @@ if(NOT ENABLE_CPU OR WIN32)
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/model_store.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/model_store.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/round.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/round.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/server.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/server.cc")
|
||||||
|
list(REMOVE_ITEM _FL_SRC_FILES "server/cert_verify.cc")
|
||||||
|
list(REMOVE_ITEM _FL_SRC_FILES "server/server_recovery.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "server/iteration_metrics.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "server/iteration_metrics.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc")
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc")
|
||||||
|
@ -49,10 +52,6 @@ if(NOT ENABLE_CPU OR WIN32)
|
||||||
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc")
|
list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-tautological-pointer-compare")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
list(LENGTH _FL_SRC_FILES fl_file_num)
|
list(LENGTH _FL_SRC_FILES fl_file_num)
|
||||||
if(NOT fl_file_num EQUAL 0)
|
if(NOT fl_file_num EQUAL 0)
|
||||||
set_property(SOURCE ${_FL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_FL)
|
set_property(SOURCE ${_FL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_FL)
|
||||||
|
|
|
@ -24,8 +24,8 @@ namespace mindspore {
|
||||||
namespace armour {
|
namespace armour {
|
||||||
bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
|
bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
|
||||||
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
|
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
|
||||||
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
|
size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt,
|
||||||
size_t cipher_reconstruct_secrets_up_cnt) {
|
size_t cipher_get_list_sign_cnt, size_t cipher_reconstruct_secrets_up_cnt) {
|
||||||
MS_LOG(INFO) << "CipherInit::Init START";
|
MS_LOG(INFO) << "CipherInit::Init START";
|
||||||
if (publicparam_.p == nullptr || param.p == nullptr || param.prime == nullptr || publicparam_.prime == nullptr) {
|
if (publicparam_.p == nullptr || param.p == nullptr || param.prime == nullptr || publicparam_.prime == nullptr) {
|
||||||
MS_LOG(ERROR) << "CipherInit::input data invalid.";
|
MS_LOG(ERROR) << "CipherInit::input data invalid.";
|
||||||
|
@ -47,6 +47,8 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
||||||
get_secrets_threshold = cipher_get_secrets_cnt;
|
get_secrets_threshold = cipher_get_secrets_cnt;
|
||||||
client_list_threshold = cipher_get_clientlist_cnt;
|
client_list_threshold = cipher_get_clientlist_cnt;
|
||||||
reconstruct_secrets_threshold = cipher_reconstruct_secrets_up_cnt;
|
reconstruct_secrets_threshold = cipher_reconstruct_secrets_up_cnt;
|
||||||
|
push_list_sign_threshold = cipher_push_list_sign_cnt;
|
||||||
|
get_list_sign_threshold = cipher_get_list_sign_cnt;
|
||||||
|
|
||||||
time_out_mutex_ = time_out_mutex;
|
time_out_mutex_ = time_out_mutex;
|
||||||
publicparam_.dp_eps = param.dp_eps;
|
publicparam_.dp_eps = param.dp_eps;
|
||||||
|
@ -74,6 +76,8 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
||||||
MS_LOG(INFO) << " CipherInit get_secrets_threshold : " << get_secrets_threshold;
|
MS_LOG(INFO) << " CipherInit get_secrets_threshold : " << get_secrets_threshold;
|
||||||
MS_LOG(INFO) << " CipherInit client_list_threshold : " << client_list_threshold;
|
MS_LOG(INFO) << " CipherInit client_list_threshold : " << client_list_threshold;
|
||||||
MS_LOG(INFO) << " CipherInit reconstruct_secrets_threshold : " << reconstruct_secrets_threshold;
|
MS_LOG(INFO) << " CipherInit reconstruct_secrets_threshold : " << reconstruct_secrets_threshold;
|
||||||
|
MS_LOG(INFO) << " CipherInit push_list_sign_threshold : " << push_list_sign_threshold;
|
||||||
|
MS_LOG(INFO) << " CipherInit get_list_sign_threshold : " << get_list_sign_threshold;
|
||||||
MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_;
|
MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_;
|
||||||
if (!Check_Parames()) {
|
if (!Check_Parames()) {
|
||||||
MS_LOG(ERROR) << "Cipher parameters are illegal.";
|
MS_LOG(ERROR) << "Cipher parameters are illegal.";
|
||||||
|
@ -81,7 +85,6 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << " CipherInit::Init Success";
|
MS_LOG(INFO) << " CipherInit::Init Success";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (param.encrypt_type == mindspore::ps::kStablePWEncryptType) {
|
if (param.encrypt_type == mindspore::ps::kStablePWEncryptType) {
|
||||||
cipher_meta_storage_.RegisterStablePWClass();
|
cipher_meta_storage_.RegisterStablePWClass();
|
||||||
MS_LOG(INFO) << "Register metadata for StablePWEncrypt is finished.";
|
MS_LOG(INFO) << "Register metadata for StablePWEncrypt is finished.";
|
||||||
|
@ -96,9 +99,10 @@ bool CipherInit::Check_Parames() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (share_secrets_threshold < reconstruct_secrets_threshold) {
|
if (share_secrets_threshold < reconstruct_secrets_threshold) {
|
||||||
MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger "
|
MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger than "
|
||||||
"than share_secrets_threshold, but got they are:"
|
"share_secrets_threshold."
|
||||||
<< reconstruct_secrets_threshold << ", " << share_secrets_threshold;
|
<< "reconstruct_secrets_threshold: " << reconstruct_secrets_threshold
|
||||||
|
<< ", share_secrets_threshold: " << share_secrets_threshold;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,27 +40,34 @@ class CipherInit {
|
||||||
// Initialize the parameters of the secure aggregation.
|
// Initialize the parameters of the secure aggregation.
|
||||||
bool Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
|
bool Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt,
|
||||||
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
|
size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt,
|
||||||
size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt,
|
size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt, size_t cipher_get_list_sign_cnt,
|
||||||
size_t cipher_reconstruct_secrets_up_cnt);
|
size_t cipher_reconstruct_secrets_up_cnt);
|
||||||
|
|
||||||
// Get public params. which is given to start fl job thread.
|
// Get public params. which is given to start fl job thread.
|
||||||
CipherPublicPara *GetPublicParams() { return &publicparam_; }
|
CipherPublicPara *GetPublicParams() { return &publicparam_; }
|
||||||
|
|
||||||
size_t share_secrets_threshold; // the minimum number of clients to share secret fragments.
|
size_t share_secrets_threshold; // the minimum number of clients to share
|
||||||
size_t reconstruct_secrets_threshold; // the minimum number of clients to reconstruct secret mask.
|
// secret fragments.
|
||||||
size_t exchange_key_threshold; // the minimum number of clients to send public keys.
|
size_t reconstruct_secrets_threshold; // the minimum number of clients to
|
||||||
size_t push_list_sign_threshold; // the minimum number of clients to push client list signature.
|
// reconstruct secret mask.
|
||||||
size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask.
|
size_t exchange_key_threshold; // the minimum number of clients to send public
|
||||||
|
// keys.
|
||||||
|
size_t push_list_sign_threshold; // the minimum number of clients to push
|
||||||
|
// client list signature.
|
||||||
|
size_t secrets_minnums_; // the minimum number of secret fragment s to
|
||||||
|
// reconstruct secret mask.
|
||||||
size_t featuremap_; // the size of data to deal.
|
size_t featuremap_; // the size of data to deal.
|
||||||
|
CipherPublicPara publicparam_; // the param containing encrypted public parameters.
|
||||||
CipherPublicPara publicparam_; // the param containing encrypted public parameters.
|
|
||||||
CipherMetaStorage cipher_meta_storage_;
|
CipherMetaStorage cipher_meta_storage_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t client_list_threshold; // the minimum number of clients to get update model client list.
|
size_t client_list_threshold; // the minimum number of clients to get update
|
||||||
|
// model client list.
|
||||||
size_t get_key_threshold; // the minimum number of clients to get public keys.
|
size_t get_key_threshold; // the minimum number of clients to get public keys.
|
||||||
size_t get_list_sign_threshold; // the minimum number of clients to get client list signature.
|
size_t get_list_sign_threshold; // the minimum number of clients to get client
|
||||||
size_t get_secrets_threshold; // the minimum number of clients to get secret fragments.
|
// list signature.
|
||||||
|
size_t get_secrets_threshold; // the minimum number of clients to get secret
|
||||||
|
// fragments.
|
||||||
size_t time_out_mutex_; // timeout mutex.
|
size_t time_out_mutex_; // timeout mutex.
|
||||||
|
|
||||||
// Check whether the parameters are valid.
|
// Check whether the parameters are valid.
|
||||||
|
|
|
@ -0,0 +1,633 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "fl/server/cert_verify.h"
|
||||||
|
#include <sys/time.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cstring>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <vector>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
#ifndef _WIN32
|
||||||
|
// File-local value with static storage duration, so it is zero-initialized before first use.
// NOTE(review): no reader or writer is visible in this part of the file; presumably the tolerated
// timestamp difference used when rejecting replayed requests — confirm meaning and unit at its users.
static int64_t replayAttackTimeDiff;
|
||||||
|
// Reads one PEM-encoded X509 certificate from the file at `certPath`.
// Returns the parsed certificate, or nullptr when the file cannot be opened or does not contain a
// parsable certificate. A non-null result must be released by the caller with X509_free.
X509 *CertVerify::readCertFromFile(const std::string &certPath) {
  BIO *bio = BIO_new_file(certPath.c_str(), "r");
  if (bio == nullptr) {
    // Do not hand a null BIO to the PEM reader; report the unreadable file explicitly instead.
    MS_LOG(ERROR) << "Failed to open cert file: " << certPath;
    return nullptr;
  }
  X509 *certObj = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
  BIO_free_all(bio);
  return certObj;
}
|
||||||
|
|
||||||
|
// Parses one PEM-encoded X509 certificate held in memory in `cert`.
// Returns the parsed certificate, or nullptr when the memory BIO cannot be created or the text is
// not a parsable certificate. A non-null result must be released by the caller with X509_free.
X509 *CertVerify::readCertFromPerm(std::string cert) {
  // A length of -1 makes OpenSSL treat the buffer as a NUL-terminated string.
  BIO *bio = BIO_new_mem_buf(static_cast<void *>(cert.data()), -1);
  if (bio == nullptr) {
    MS_LOG(ERROR) << "Failed to create memory BIO for cert.";
    return nullptr;
  }
  X509 *certObj = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr);
  BIO_free_all(bio);
  return certObj;
}
|
||||||
|
|
||||||
|
// Reads one PEM-encoded certificate revocation list from the file at `crlPath`.
// Returns the parsed CRL, or nullptr when the file cannot be opened or does not contain a parsable CRL.
X509_CRL *CertVerify::readCrlFromFile(const std::string &crlPath) {
  BIO *bio = BIO_new_file(crlPath.c_str(), "r");
  if (bio == nullptr) {
    // Do not hand a null BIO to the PEM reader; report the unreadable file explicitly instead.
    MS_LOG(ERROR) << "Failed to open crl file: " << crlPath;
    return nullptr;
  }
  X509_CRL *crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
  BIO_free_all(bio);
  return crl;
}
|
||||||
|
|
||||||
|
// Reports whether `file` can be opened for reading.
// The probe stream is closed by its destructor when the function returns.
bool checkFileExists(const std::string &file) {
  std::ifstream probe(file.c_str());
  return probe.good();
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyCertTime(const X509 *cert) const {
|
||||||
|
ASN1_TIME *start = X509_getm_notBefore(cert);
|
||||||
|
ASN1_TIME *end = X509_getm_notAfter(cert);
|
||||||
|
|
||||||
|
int day = 0;
|
||||||
|
int sec = 0;
|
||||||
|
int ret = ASN1_TIME_diff(&day, &sec, start, NULL);
|
||||||
|
if (ret != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (day < 0 || sec < 0) {
|
||||||
|
MS_LOG(ERROR) << "cert start time is later than now time.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
day = 0;
|
||||||
|
sec = 0;
|
||||||
|
ret = ASN1_TIME_diff(&day, &sec, NULL, end);
|
||||||
|
if (ret != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (day < 0 || sec < 0) {
|
||||||
|
MS_LOG(ERROR) << "cert end time is sooner than now time.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyPublicKey(const X509 *keyAttestationCertObj, const X509 *equipCertObj,
|
||||||
|
const X509 *equipCACertObj, const X509 *rootFirstCA, const X509 *rootSecondCA) const {
|
||||||
|
bool result = true;
|
||||||
|
EVP_PKEY *equipPubKey = X509_get_pubkey(const_cast<X509 *>(equipCertObj));
|
||||||
|
EVP_PKEY *equipCAPubKey = X509_get_pubkey(const_cast<X509 *>(equipCACertObj));
|
||||||
|
EVP_PKEY *rootFirstPubKey = X509_get_pubkey(const_cast<X509 *>(rootFirstCA));
|
||||||
|
EVP_PKEY *rootSecondPubKey = X509_get_pubkey(const_cast<X509 *>(rootSecondCA));
|
||||||
|
do {
|
||||||
|
int ret = 0;
|
||||||
|
ret = X509_verify(const_cast<X509 *>(keyAttestationCertObj), equipPubKey);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "keyAttestationCert verify is failed";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
ret = X509_verify(const_cast<X509 *>(equipCertObj), equipCAPubKey);
|
||||||
|
if (ret != 1) {
|
||||||
|
MS_LOG(ERROR) << "equip cert verify is failed";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
int ret_first = X509_verify(const_cast<X509 *>(equipCACertObj), rootFirstPubKey);
|
||||||
|
int ret_second = X509_verify(const_cast<X509 *>(equipCACertObj), rootSecondPubKey);
|
||||||
|
if (ret_first != 1 && ret_second != 1) {
|
||||||
|
MS_LOG(ERROR) << "equip ca cert verify is failed";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (0);
|
||||||
|
|
||||||
|
EVP_PKEY_free(equipPubKey);
|
||||||
|
EVP_PKEY_free(equipCAPubKey);
|
||||||
|
EVP_PKEY_free(rootFirstPubKey);
|
||||||
|
EVP_PKEY_free(rootSecondPubKey);
|
||||||
|
MS_LOG(INFO) << "verify Public Key success.";
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyCAChain(const std::string &keyAttestation, const std::string &equipCert,
|
||||||
|
const std::string &equipCACert, const std::string &rootFirstCAPath,
|
||||||
|
const std::string &rootSecondCAPath) {
|
||||||
|
X509 *rootFirstCA = CertVerify::readCertFromFile(rootFirstCAPath);
|
||||||
|
X509 *rootSecondCA = CertVerify::readCertFromFile(rootSecondCAPath);
|
||||||
|
X509 *keyAttestationCertObj = readCertFromPerm(keyAttestation);
|
||||||
|
X509 *equipCertObj = readCertFromPerm(equipCert);
|
||||||
|
X509 *equipCACertObj = readCertFromPerm(equipCACert);
|
||||||
|
bool result = true;
|
||||||
|
do {
|
||||||
|
if (rootFirstCA == nullptr || rootSecondCA == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "rootFirstCA or rootSecondCA is nullptr";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (keyAttestationCertObj == nullptr || equipCertObj == nullptr || equipCACertObj == nullptr) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyCertTime(keyAttestationCertObj) || !verifyCertTime(equipCertObj) || !verifyCertTime(equipCACertObj)) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyCertCommonName(equipCACertObj, equipCertObj)) {
|
||||||
|
MS_LOG(ERROR) << "equip ca cert subject cn is not equal with equip cert issuer cn.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyCertCommonName(rootFirstCA, equipCACertObj) && !verifyCertCommonName(rootSecondCA, equipCACertObj)) {
|
||||||
|
MS_LOG(ERROR) << "root CA cert subject cn is not equal with equip CA cert issuer cn.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyExtendedAttributes(equipCACertObj)) {
|
||||||
|
MS_LOG(ERROR) << "verify equipCACert Extended Attributes failed.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyCertKeyID(rootFirstCA, equipCACertObj) && !verifyCertKeyID(rootSecondCA, equipCACertObj)) {
|
||||||
|
MS_LOG(ERROR) << "root CA cert subject keyid is not equal with equip CA cert issuer keyid.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyCertKeyID(equipCACertObj, equipCertObj)) {
|
||||||
|
MS_LOG(ERROR) << "equip CA cert subject keyid is not equal with equip cert issuer keyid.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyPublicKey(keyAttestationCertObj, equipCertObj, equipCACertObj, rootFirstCA, rootSecondCA)) {
|
||||||
|
MS_LOG(ERROR) << "verify Public Key failed";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (0);
|
||||||
|
X509_free(rootFirstCA);
|
||||||
|
X509_free(rootSecondCA);
|
||||||
|
X509_free(keyAttestationCertObj);
|
||||||
|
X509_free(equipCertObj);
|
||||||
|
X509_free(equipCACertObj);
|
||||||
|
MS_LOG(INFO) << "verifyCAChain success.";
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyCertKeyID(const X509 *caCert, const X509 *subCert) const {
|
||||||
|
bool result = true;
|
||||||
|
ASN1_OCTET_STRING *skid = nullptr;
|
||||||
|
AUTHORITY_KEYID *akeyid = nullptr;
|
||||||
|
do {
|
||||||
|
int crit = 0;
|
||||||
|
skid = reinterpret_cast<ASN1_OCTET_STRING *>(X509_get_ext_d2i(caCert, NID_subject_key_identifier, &crit, NULL));
|
||||||
|
if (skid == nullptr) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
char subject_keyid[512] = {0};
|
||||||
|
for (int i = 0; i < skid->length; i++) {
|
||||||
|
char keyid[8] = {0};
|
||||||
|
int base = 512;
|
||||||
|
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)skid->data[i]);
|
||||||
|
int ret = strcat_s(subject_keyid, base, keyid);
|
||||||
|
if (ret == -1) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
akeyid = reinterpret_cast<AUTHORITY_KEYID *>(X509_get_ext_d2i(subCert, NID_authority_key_identifier, &crit, NULL));
|
||||||
|
if (akeyid == nullptr) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
char issuer_keyid[512] = {0};
|
||||||
|
if (akeyid->keyid == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "keyid is nullprt.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < akeyid->keyid->length; i++) {
|
||||||
|
char keyid[8] = {0};
|
||||||
|
int base = 512;
|
||||||
|
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)(akeyid->keyid->data[i]));
|
||||||
|
int ret = strcat_s(issuer_keyid, base, keyid);
|
||||||
|
if (ret == -1) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string subject_keyid_str = subject_keyid;
|
||||||
|
std::string issuer_keyid_str = issuer_keyid;
|
||||||
|
if (subject_keyid_str != issuer_keyid_str) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (0);
|
||||||
|
ASN1_OCTET_STRING_free(skid);
|
||||||
|
AUTHORITY_KEYID_free(akeyid);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyExtendedAttributes(const X509 *cert) const {
|
||||||
|
bool result = true;
|
||||||
|
BASIC_CONSTRAINTS *bcons = nullptr;
|
||||||
|
ASN1_BIT_STRING *lASN1UsageStr = nullptr;
|
||||||
|
do {
|
||||||
|
int cirt = 0;
|
||||||
|
bcons = reinterpret_cast<BASIC_CONSTRAINTS *>(X509_get_ext_d2i(cert, NID_basic_constraints, &cirt, NULL));
|
||||||
|
if (bcons == nullptr) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (!bcons->ca) {
|
||||||
|
MS_LOG(ERROR) << "Subject Type is End Entity.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Subject Type is CA.";
|
||||||
|
|
||||||
|
lASN1UsageStr = reinterpret_cast<ASN1_BIT_STRING *>(X509_get_ext_d2i(cert, NID_key_usage, NULL, NULL));
|
||||||
|
if (lASN1UsageStr == nullptr) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
int16_t usage = lASN1UsageStr->data[0];
|
||||||
|
if (lASN1UsageStr->length > 1) {
|
||||||
|
const unsigned int move = 8;
|
||||||
|
usage |= lASN1UsageStr->data[1] << move;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!(usage & KU_KEY_CERT_SIGN)) {
|
||||||
|
MS_LOG(ERROR) << "Subject is not Certificate Signature.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Subject is Certificate Signature.";
|
||||||
|
} while (0);
|
||||||
|
BASIC_CONSTRAINTS_free(bcons);
|
||||||
|
ASN1_BIT_STRING_free(lASN1UsageStr);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyCertCommonName(const X509 *caCert, const X509 *subCert) const {
|
||||||
|
if (caCert == nullptr || subCert == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
char caSubjectCN[256] = "";
|
||||||
|
char subIssuerCN[256] = "";
|
||||||
|
|
||||||
|
X509_NAME *caSubjectX509CN = X509_get_subject_name(caCert);
|
||||||
|
X509_NAME *subIssuerX509CN = X509_get_issuer_name(subCert);
|
||||||
|
|
||||||
|
int ret = X509_NAME_get_text_by_NID(caSubjectX509CN, NID_commonName, caSubjectCN, sizeof(caSubjectCN));
|
||||||
|
if (ret < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ret = X509_NAME_get_text_by_NID(subIssuerX509CN, NID_commonName, subIssuerCN, sizeof(subIssuerCN));
|
||||||
|
if (ret < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string caSubjectCNStr = caSubjectCN;
|
||||||
|
std::string subIssuerCNStr = subIssuerCN;
|
||||||
|
|
||||||
|
if (caSubjectCNStr != subIssuerCNStr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyCRL(const std::string &equipCert, const std::string &equipCrlPath) {
|
||||||
|
if (!checkFileExists(equipCrlPath)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool result = true;
|
||||||
|
X509_CRL *equipCrl = nullptr;
|
||||||
|
X509 *equipCertObj = nullptr;
|
||||||
|
EVP_PKEY *evp_pkey = nullptr;
|
||||||
|
do {
|
||||||
|
equipCrl = CertVerify::readCrlFromFile(equipCrlPath);
|
||||||
|
equipCertObj = readCertFromPerm(equipCert);
|
||||||
|
if (equipCertObj == nullptr) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (equipCrl == nullptr) {
|
||||||
|
MS_LOG(INFO) << "equipCrl is nullptr. return true.";
|
||||||
|
result = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
evp_pkey = X509_get_pubkey(equipCertObj);
|
||||||
|
int ret = X509_CRL_verify(equipCrl, evp_pkey);
|
||||||
|
if (ret == 1) {
|
||||||
|
MS_LOG(ERROR) << "equip cert in equip crl, verify failed";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (0);
|
||||||
|
|
||||||
|
EVP_PKEY_free(evp_pkey);
|
||||||
|
X509_free(equipCertObj);
|
||||||
|
X509_CRL_free(equipCrl);
|
||||||
|
MS_LOG(INFO) << "verifyCRL success.";
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned char *signData, const std::string &flID,
|
||||||
|
const std::string &timeStamp) {
|
||||||
|
if (keyAttestation.empty() || signData == nullptr || flID.empty() || timeStamp.empty()) {
|
||||||
|
MS_LOG(ERROR) << "keyAttestation or signData or flID or timeStamp is empty.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool result = true;
|
||||||
|
X509 *keyAttestationCertObj = nullptr;
|
||||||
|
EVP_PKEY *pubKey = nullptr;
|
||||||
|
do {
|
||||||
|
keyAttestationCertObj = readCertFromPerm(keyAttestation);
|
||||||
|
|
||||||
|
std::string srcData = flID + " " + timeStamp;
|
||||||
|
// SHA256_DIGEST_LENGTH is 32
|
||||||
|
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
|
||||||
|
sha256Hash(srcData, srcDataHash, SHA256_DIGEST_LENGTH);
|
||||||
|
|
||||||
|
pubKey = X509_get_pubkey(keyAttestationCertObj);
|
||||||
|
RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey);
|
||||||
|
if (pRSAPublicKey == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "get rsa public key failed.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int pubKeyLen = RSA_size(pRSAPublicKey);
|
||||||
|
unsigned char buffer[256];
|
||||||
|
int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING);
|
||||||
|
if (ret == -1) {
|
||||||
|
MS_LOG(ERROR) << "rsa public decrypt failed.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int saltLen = -2;
|
||||||
|
ret = RSA_verify_PKCS1_PSS(pRSAPublicKey, srcDataHash, EVP_sha256(), buffer, saltLen);
|
||||||
|
if (ret != 1) {
|
||||||
|
int64_t ulErr = SizeToLong(ERR_get_error());
|
||||||
|
char szErrMsg[1024] = {0};
|
||||||
|
MS_LOG(ERROR) << "verify error. error number: " << ulErr;
|
||||||
|
std::string str_res = ERR_error_string(ulErr, szErrMsg);
|
||||||
|
MS_LOG(ERROR) << szErrMsg;
|
||||||
|
if (str_res.empty()) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (0);
|
||||||
|
EVP_PKEY_free(pubKey);
|
||||||
|
X509_free(keyAttestationCertObj);
|
||||||
|
CRYPTO_cleanup_all_ex_data();
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "verifyRSAKey success.";
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CertVerify::sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash, const int len) const {
|
||||||
|
if (len <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
SHA256_CTX sha_ctx;
|
||||||
|
int ret = SHA256_Init(&sha_ctx);
|
||||||
|
if (ret != 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ret = SHA256_Update(&sha_ctx, src, src_len);
|
||||||
|
if (ret != 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ret = SHA256_Final(hash, &sha_ctx);
|
||||||
|
if (ret != 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CertVerify::toHexString(const unsigned char *data, const int len) {
|
||||||
|
if (data == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "data hash is null.";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (len <= 0) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
std::stringstream ss;
|
||||||
|
int base = 2;
|
||||||
|
for (int i = 0; i < len; i++) {
|
||||||
|
ss << std::hex << std::setw(base) << std::setfill('0') << static_cast<int>(data[i]);
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyEquipCertAndFlID(const std::string &flID, const std::string &equipCert) {
|
||||||
|
unsigned char hash[SHA256_DIGEST_LENGTH] = {""};
|
||||||
|
sha256Hash(equipCert, hash, SHA256_DIGEST_LENGTH);
|
||||||
|
std::string equipCertSha256 = toHexString(hash, SHA256_DIGEST_LENGTH);
|
||||||
|
if (flID == equipCertSha256) {
|
||||||
|
MS_LOG(INFO) << "verifyEquipCertAndFlID success.";
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "verifyEquipCertAndFlID failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyTimeStamp(const std::string &flID, const std::string &timeStamp) const {
|
||||||
|
int64_t requestTime = std::stoll(timeStamp.c_str());
|
||||||
|
const int64_t base = 1000;
|
||||||
|
struct timeval tv {};
|
||||||
|
int ret = gettimeofday(&tv, nullptr);
|
||||||
|
if (ret != 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
int64_t now = tv.tv_sec * base + tv.tv_usec / base;
|
||||||
|
MS_LOG(INFO) << "flID: " << flID.c_str() << ",now time: " << now << ",requestTime: " << requestTime;
|
||||||
|
|
||||||
|
int64_t diff = now - requestTime;
|
||||||
|
if (diff > replayAttackTimeDiff || diff < 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "verifyTimeStamp success.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CertVerify::sha256Hash(const std::string &src, uint8_t *hash, const int len) const {
|
||||||
|
if (len <= 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
SHA256_CTX sha_ctx;
|
||||||
|
int ret = SHA256_Init(&sha_ctx);
|
||||||
|
if (ret != 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ret = SHA256_Update(&sha_ctx, src.c_str(), src.size());
|
||||||
|
if (ret != 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ret = SHA256_Final(hash, &sha_ctx);
|
||||||
|
if (ret != 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData,
|
||||||
|
int srcDataLen) {
|
||||||
|
if (keyAttestation.empty() || signData == nullptr || srcData == nullptr || srcDataLen <= 0) {
|
||||||
|
MS_LOG(ERROR) << "keyAttestation or signData or srcData is invalid.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool result = true;
|
||||||
|
X509 *keyAttestationCertObj = nullptr;
|
||||||
|
EVP_PKEY *pubKey = nullptr;
|
||||||
|
do {
|
||||||
|
keyAttestationCertObj = readCertFromPerm(keyAttestation);
|
||||||
|
pubKey = X509_get_pubkey(keyAttestationCertObj);
|
||||||
|
RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey);
|
||||||
|
if (pRSAPublicKey == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "get rsa public key failed.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int pubKeyLen = RSA_size(pRSAPublicKey);
|
||||||
|
unsigned char buffer[256];
|
||||||
|
int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING);
|
||||||
|
if (ret == -1) {
|
||||||
|
MS_LOG(ERROR) << "rsa public decrypt failed.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
int saltLen = -2;
|
||||||
|
ret = RSA_verify_PKCS1_PSS(pRSAPublicKey, srcData, EVP_sha256(), buffer, saltLen);
|
||||||
|
if (ret != 1) {
|
||||||
|
int64_t ulErr = SizeToLong(ERR_get_error());
|
||||||
|
char szErrMsg[1024] = {0};
|
||||||
|
MS_LOG(ERROR) << "verify error. error number: " << ulErr;
|
||||||
|
std::string str_res = ERR_error_string(ulErr, szErrMsg);
|
||||||
|
MS_LOG(ERROR) << szErrMsg;
|
||||||
|
if (str_res.empty()) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (0);
|
||||||
|
EVP_PKEY_free(pubKey);
|
||||||
|
X509_free(keyAttestationCertObj);
|
||||||
|
CRYPTO_cleanup_all_ex_data();
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "verifyRSAKey success.";
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath,
|
||||||
|
const std::string equipCrlPath, const uint64_t replay_attack_time_diff) {
|
||||||
|
if (rootFirstCaFilePath.empty() || rootSecondCaFilePath.empty()) {
|
||||||
|
MS_LOG(ERROR) << "the root or crl path is empty.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!checkFileExists(rootFirstCaFilePath)) {
|
||||||
|
MS_LOG(ERROR) << "The rootFirstCaFilePath is not exist.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!checkFileExists(rootSecondCaFilePath)) {
|
||||||
|
MS_LOG(ERROR) << "The rootSecondCaFilePath is not exist.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
replayAttackTimeDiff = UlongToLong(replay_attack_time_diff);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CertVerify::verifyCertAndSign(const std::string &flID, const std::string &timeStamp, const unsigned char *signData,
|
||||||
|
const std::string &keyAttestation, const std::string &equipCert,
|
||||||
|
const std::string &equipCACert, const std::string &rootFirstCAPath,
|
||||||
|
const std::string &rootSecondCAPath, const std::string &equipCrlPath) {
|
||||||
|
if (!verifyEquipCertAndFlID(flID, equipCert)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyCAChain(keyAttestation, equipCert, equipCACert, rootFirstCAPath, rootSecondCAPath)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyCRL(equipCert, equipCrlPath)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyRSAKey(keyAttestation, signData, flID, timeStamp)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!verifyTimeStamp(flID, timeStamp)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
bool CertVerify::verifyTimeStamp(const std::string &flID, const std::string &timeStamp) const {
|
||||||
|
MS_LOG(WARNING) << "verifyTimeStamp in win32 platform.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
void CertVerify::sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash, const int len) const {
|
||||||
|
MS_LOG(WARNING) << "sha256Hash in win32 platform.";
|
||||||
|
}
|
||||||
|
bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData,
|
||||||
|
int srcDataLen) {
|
||||||
|
MS_LOG(WARNING) << "verifyRSAKey in win32 platform.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool CertVerify::initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath,
|
||||||
|
const std::string equipCrlPath, const uint64_t replay_attack_time_diff) {
|
||||||
|
MS_LOG(WARNING) << "initRootCertAndCRL in win32 platform.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
bool CertVerify::verifyCertAndSign(const std::string &flID, const std::string &timeStamp, const unsigned char *signData,
|
||||||
|
const std::string &keyAttestation, const std::string &equipCert,
|
||||||
|
const std::string &equipCACert, const std::string &rootFirstCAPath,
|
||||||
|
const std::string &rootSecondCAPath, const std::string &equipCrlPath) {
|
||||||
|
MS_LOG(WARNING) << "verifyCertAndSign in win32 platform.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,105 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_CERT_VERIFY_H
|
||||||
|
#define MINDSPORE_CCSRC_FL_SERVER_CERT_VERIFY_H
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#ifndef _WIN32
|
||||||
|
#include <openssl/evp.h>
|
||||||
|
#include <openssl/rsa.h>
|
||||||
|
#include <openssl/x509v3.h>
|
||||||
|
#include <openssl/err.h>
|
||||||
|
#include <openssl/pem.h>
|
||||||
|
#include <openssl/sha.h>
|
||||||
|
#endif
|
||||||
|
#include <iostream>
|
||||||
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "fl/server/common.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ps {
|
||||||
|
namespace server {
|
||||||
|
class CertVerify {
|
||||||
|
public:
|
||||||
|
CertVerify() {}
|
||||||
|
~CertVerify() = default;
|
||||||
|
|
||||||
|
bool verifyCertAndSign(const std::string &flID, const std::string &timeStamp, const unsigned char *signData,
|
||||||
|
const std::string &keyAttestation, const std::string &equipCert,
|
||||||
|
const std::string &equipCACert, const std::string &rootFirstCAPath,
|
||||||
|
const std::string &rootSecondCAPath, const std::string &equipCrlPath);
|
||||||
|
|
||||||
|
static bool initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath,
|
||||||
|
const std::string equipCrlPath, uint64_t replay_attack_time_diff_);
|
||||||
|
|
||||||
|
// verify valid of sign data
|
||||||
|
bool verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData, int srcDataLen);
|
||||||
|
|
||||||
|
void sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash, const int len) const;
|
||||||
|
|
||||||
|
// verify valid of time stamp of request
|
||||||
|
bool verifyTimeStamp(const std::string &flID, const std::string &timeStamp) const;
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
|
||||||
|
private:
|
||||||
|
// read certificate from file path
|
||||||
|
static X509 *readCertFromFile(const std::string &certPath);
|
||||||
|
|
||||||
|
// read Certificate Revocation List from file absolute path
|
||||||
|
static X509_CRL *readCrlFromFile(const std::string &crlPath);
|
||||||
|
|
||||||
|
// read certificate from pem string
|
||||||
|
X509 *readCertFromPerm(std::string cert);
|
||||||
|
|
||||||
|
// verify valid of certificate time
|
||||||
|
bool verifyCertTime(const X509 *cert) const;
|
||||||
|
|
||||||
|
// verify valid of certificate chain
|
||||||
|
bool verifyCAChain(const std::string &keyAttestation, const std::string &equipCert, const std::string &equipCACert,
|
||||||
|
const std::string &rootFirstCAPath, const std::string &rootSecondCAPath);
|
||||||
|
|
||||||
|
// verify valid of sign data
|
||||||
|
bool verifyRSAKey(const std::string &keyAttestation, const unsigned char *signData, const std::string &flID,
|
||||||
|
const std::string &timeStamp);
|
||||||
|
|
||||||
|
// verify valid of equip certificate with CRL
|
||||||
|
bool verifyCRL(const std::string &equipCert, const std::string &equipCrlPath);
|
||||||
|
|
||||||
|
// verify valid of flID with sha256(equip cert)
|
||||||
|
bool verifyEquipCertAndFlID(const std::string &flID, const std::string &equipCert);
|
||||||
|
|
||||||
|
void sha256Hash(const std::string &src, uint8_t *hash, const int len) const;
|
||||||
|
|
||||||
|
std::string toHexString(const unsigned char *data, const int len);
|
||||||
|
|
||||||
|
bool verifyCertCommonName(const X509 *caCert, const X509 *subCert) const;
|
||||||
|
|
||||||
|
bool verifyExtendedAttributes(const X509 *cert) const;
|
||||||
|
|
||||||
|
bool verifyCertKeyID(const X509 *caCert, const X509 *subCert) const;
|
||||||
|
|
||||||
|
bool verifyPublicKey(const X509 *keyAttestationCertObj, const X509 *equipCertObj, const X509 *equipCACertObj,
|
||||||
|
const X509 *rootFirstCA, const X509 *rootSecondCA) const;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
} // namespace server
|
||||||
|
} // namespace ps
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_FL_SERVER_CERT_VERIFY_H
|
|
@ -93,7 +93,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size
|
||||||
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 3: Reduce the data so we can overlap the time cost of send.
|
// Step 3: Reduce the data so we can overlap the time cost of send.
|
||||||
for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) {
|
for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) {
|
||||||
recv_chunk[j] += tmp_recv_chunk[j];
|
recv_chunk[j] += tmp_recv_chunk[j];
|
||||||
|
@ -164,9 +163,9 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
||||||
for (uint32_t i = 1; i < rank_size; i++) {
|
for (uint32_t i = 1; i < rank_size; i++) {
|
||||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
|
MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i;
|
||||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
|
auto recv_req_id1 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str);
|
||||||
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
|
if (!server_node_->CollectiveWait(recv_req_id1, kCollectiveCommTimeout)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id1 << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ret = memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
|
ret = memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size());
|
||||||
|
@ -180,9 +179,9 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
|
MS_LOG(DEBUG) << "Reduce send data to rank 0 process.";
|
||||||
auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
|
auto send_req_id1 = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T));
|
||||||
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
|
if (!server_node_->Wait(send_req_id1, kCollectiveCommTimeout)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << send_req_id1 << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -193,19 +192,19 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
||||||
if (rank_id_ == 0) {
|
if (rank_id_ == 0) {
|
||||||
for (uint32_t i = 1; i < rank_size; i++) {
|
for (uint32_t i = 1; i < rank_size; i++) {
|
||||||
MS_LOG(DEBUG) << "Broadcast data to process " << i;
|
MS_LOG(DEBUG) << "Broadcast data to process " << i;
|
||||||
auto send_req_id =
|
auto send_req_id2 =
|
||||||
server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
|
server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T));
|
||||||
if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) {
|
if (!server_node_->Wait(send_req_id2, kCollectiveCommTimeout)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << send_req_id2 << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
|
MS_LOG(DEBUG) << "Broadcast receive from rank 0.";
|
||||||
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
std::shared_ptr<std::vector<unsigned char>> recv_str;
|
||||||
auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
|
auto recv_req_id2 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str);
|
||||||
if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) {
|
if (!server_node_->CollectiveWait(recv_req_id2, kCollectiveCommTimeout)) {
|
||||||
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed.";
|
MS_LOG(ERROR) << "CollectiveWait " << recv_req_id2 << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ret = memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
|
ret = memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size());
|
||||||
|
@ -219,7 +218,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count) {
|
bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff, size_t send_count) {
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(node_, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(node_, false);
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false);
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false);
|
||||||
|
@ -230,8 +229,8 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size
|
||||||
// Store offsets to get every data chunk's address.
|
// Store offsets to get every data chunk's address.
|
||||||
std::vector<size_t> chunk_offset;
|
std::vector<size_t> chunk_offset;
|
||||||
for (size_t i = 0; i < rank_size_; i++) {
|
for (size_t i = 0; i < rank_size_; i++) {
|
||||||
size_t ofs =
|
size_t ofs = std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + SizeToLong(i), static_cast<size_t>(0),
|
||||||
std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast<size_t>(0), std::plus<size_t>());
|
std::plus<size_t>());
|
||||||
chunk_offset.push_back(ofs);
|
chunk_offset.push_back(ofs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,7 +294,7 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c
|
||||||
MS_LOG(ERROR) << "The group is empty.";
|
MS_LOG(ERROR) << "The group is empty.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
uint32_t group_rank_size = group_info.group_ranks.size();
|
uint32_t group_rank_size = SizeToUint(group_info.group_ranks.size());
|
||||||
uint32_t global_root_rank = group_to_global_ranks[root];
|
uint32_t global_root_rank = group_to_global_ranks[root];
|
||||||
|
|
||||||
// Broadcast data to processes which are not the root.
|
// Broadcast data to processes which are not the root.
|
||||||
|
@ -349,7 +348,7 @@ bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t c
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t send_count,
|
bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *const recvbuff, size_t send_count,
|
||||||
const std::shared_ptr<ps::core::AbstractNode> &node) {
|
const std::shared_ptr<ps::core::AbstractNode> &node) {
|
||||||
std::unique_lock<std::mutex> lock(mtx_);
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(node, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(node, false);
|
||||||
|
@ -362,10 +361,10 @@ bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t s
|
||||||
rank_id_ = node_->rank_id();
|
rank_id_ = node_->rank_id();
|
||||||
switch (node_role_) {
|
switch (node_role_) {
|
||||||
case ps::core::WORKER:
|
case ps::core::WORKER:
|
||||||
rank_size_ = node_->worker_num();
|
rank_size_ = IntToUint(node_->worker_num());
|
||||||
break;
|
break;
|
||||||
case ps::core::SERVER:
|
case ps::core::SERVER:
|
||||||
rank_size_ = node_->server_num();
|
rank_size_ = IntToUint(node_->server_num());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(ERROR) << "The node role " << node_role_ << " for collective communication is invalid.";
|
MS_LOG(ERROR) << "The node role " << node_role_ << " for collective communication is invalid.";
|
||||||
|
@ -380,7 +379,7 @@ bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t s
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root,
|
bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *const recvbuff, size_t count, uint32_t root,
|
||||||
const std::shared_ptr<ps::core::AbstractNode> &node,
|
const std::shared_ptr<ps::core::AbstractNode> &node,
|
||||||
const CommunicationGroupInfo &group_info) {
|
const CommunicationGroupInfo &group_info) {
|
||||||
std::unique_lock<std::mutex> lock(mtx_);
|
std::unique_lock<std::mutex> lock(mtx_);
|
||||||
|
|
|
@ -45,17 +45,17 @@ enum CommType { HTTP = 0, TCP };
|
||||||
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
|
enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum };
|
||||||
|
|
||||||
struct RoundConfig {
|
struct RoundConfig {
|
||||||
// The name of round. Please refer to round kernel *.cc files.
|
// The name of the round. Please refer to round kernel *.cc files.
|
||||||
std::string name;
|
std::string name;
|
||||||
// Whether this round has the time window limit.
|
// Whether this round has the time window limit.
|
||||||
bool check_timeout = false;
|
bool check_timeout = false;
|
||||||
// The length of the time window. Only used when check_timeout is set to true.
|
// The length of the time window. Only used when check_timeout is set to true.
|
||||||
size_t time_window = 3000;
|
size_t time_window = 3000;
|
||||||
// Whether this round has to check the request count has reached the threshold.
|
// Whether this round has to check that the request count has reached the threshold.
|
||||||
bool check_count = false;
|
bool check_count = false;
|
||||||
// This round's request threshold count. Only used when check_count is set to true.
|
// This round's request threshold count. Only used when check_count is set to true.
|
||||||
size_t threshold_count = 0;
|
size_t threshold_count = 0;
|
||||||
// Whether this round uses the server as threshold count. This is vital for some rounds in elastic scaling scenario.
|
// Whether this round uses the server number as threshold. This is vital for some rounds in elastic scaling scenario.
|
||||||
bool server_num_as_threshold = false;
|
bool server_num_as_threshold = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -67,6 +67,8 @@ struct CipherConfig {
|
||||||
size_t share_secrets_threshold = 0;
|
size_t share_secrets_threshold = 0;
|
||||||
size_t get_secrets_threshold = 0;
|
size_t get_secrets_threshold = 0;
|
||||||
size_t client_list_threshold = 0;
|
size_t client_list_threshold = 0;
|
||||||
|
size_t push_list_sign_threshold = 0;
|
||||||
|
size_t get_list_sign_threshold = 0;
|
||||||
size_t reconstruct_secrets_threshold = 0;
|
size_t reconstruct_secrets_threshold = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -130,7 +132,6 @@ constexpr auto kAdamEps = "eps";
|
||||||
constexpr auto kFtrlLinear = "linear";
|
constexpr auto kFtrlLinear = "linear";
|
||||||
constexpr auto kDataSize = "data_size";
|
constexpr auto kDataSize = "data_size";
|
||||||
constexpr auto kNewDataSize = "new_data_size";
|
constexpr auto kNewDataSize = "new_data_size";
|
||||||
constexpr auto kStat = "stat";
|
|
||||||
|
|
||||||
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
|
// OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is
|
||||||
// launched.
|
// launched.
|
||||||
|
@ -175,14 +176,11 @@ const OptimParamNameToIndex kAdamWeightDecayNameToIdx = {{"inputs",
|
||||||
{"weight_decay", 7},
|
{"weight_decay", 7},
|
||||||
{"grad", 8}}},
|
{"grad", 8}}},
|
||||||
{"outputs", {}}};
|
{"outputs", {}}};
|
||||||
const OptimParamNameToIndex kSGDNameToIdx = {
|
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {{kApplyMomentumOpName, kMomentumNameToIdx},
|
||||||
{"inputs", {{kWeight, 0}, {kGradient, 1}, {kLearningRate, 2}, {kAccumulation, 3}, {kMomentum, 4}, {kStat, 5}}},
|
{kFusedSparseAdamName, kSparseAdamNameToIdx},
|
||||||
{"outputs", {}}};
|
{kSparseApplyFtrlOpName, kSparseFtrlNameToIdx},
|
||||||
|
{kApplyAdamOpName, kAdamNameToIdx},
|
||||||
const std::map<std::string, OptimParamNameToIndex> kNameToIdxMap = {
|
{"AdamWeightDecay", kAdamWeightDecayNameToIdx}};
|
||||||
{kApplyMomentumOpName, kMomentumNameToIdx}, {kFusedSparseAdamName, kSparseAdamNameToIdx},
|
|
||||||
{kSparseApplyFtrlOpName, kSparseFtrlNameToIdx}, {kApplyAdamOpName, kAdamNameToIdx},
|
|
||||||
{"AdamWeightDecay", kAdamWeightDecayNameToIdx}, {kSGDName, kSGDNameToIdx}};
|
|
||||||
|
|
||||||
constexpr uint32_t kLeaderServerRank = 0;
|
constexpr uint32_t kLeaderServerRank = 0;
|
||||||
constexpr size_t kWorkerMgrThreadPoolSize = 32;
|
constexpr size_t kWorkerMgrThreadPoolSize = 32;
|
||||||
|
@ -215,9 +213,12 @@ constexpr auto kCtxGetSecretsClientList = "get_secrets_client_list";
|
||||||
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
|
constexpr auto kCtxReconstructClientList = "reconstruct_client_list";
|
||||||
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
|
constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list";
|
||||||
constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list";
|
constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list";
|
||||||
|
constexpr auto kCtxClientListSigns = "client_list_signs";
|
||||||
|
constexpr auto kCtxClientKeyAttestation = "client_key_attestation";
|
||||||
constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
|
constexpr auto kCtxGetKeysClientList = "get_keys_client_list";
|
||||||
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
|
constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size";
|
||||||
constexpr auto kCtxCipherPrimer = "cipher_primer";
|
constexpr auto kCtxCipherPrimer = "cipher_primer";
|
||||||
|
constexpr auto kCurrentIteration = "current_iteration";
|
||||||
|
|
||||||
// This macro yields the current timestamp in milliseconds.
|
// This macro yields the current timestamp in milliseconds.
|
||||||
#define CURRENT_TIME_MILLI \
|
#define CURRENT_TIME_MILLI \
|
||||||
|
@ -252,6 +253,14 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size
|
||||||
return addr;
|
return addr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) {
|
||||||
|
if (!json.contains(key)) {
|
||||||
|
MS_LOG(EXCEPTION) << "The key " << key << "does not exist in json " << json.dump();
|
||||||
|
}
|
||||||
|
return json[key].get<T>();
|
||||||
|
}
|
||||||
|
|
||||||
// Definitions for Federated Learning.
|
// Definitions for Federated Learning.
|
||||||
|
|
||||||
constexpr auto kNetworkError = "Cluster networking failed.";
|
constexpr auto kNetworkError = "Cluster networking failed.";
|
||||||
|
|
|
@ -26,7 +26,7 @@ bool ConsistentHashRing::Insert(uint32_t rank) {
|
||||||
MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank << ", hash value is "
|
MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank << ", hash value is "
|
||||||
<< hash_value;
|
<< hash_value;
|
||||||
if (ring_.count(hash_value) != 0) {
|
if (ring_.count(hash_value) != 0) {
|
||||||
MS_LOG(INFO) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring.";
|
MS_LOG(WARNING) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring.";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
ring_[hash_value] = rank;
|
ring_[hash_value] = rank;
|
||||||
|
@ -37,7 +37,7 @@ bool ConsistentHashRing::Insert(uint32_t rank) {
|
||||||
bool ConsistentHashRing::Erase(uint32_t rank) {
|
bool ConsistentHashRing::Erase(uint32_t rank) {
|
||||||
for (auto iterator = ring_.begin(); iterator != ring_.end();) {
|
for (auto iterator = ring_.begin(); iterator != ring_.end();) {
|
||||||
if (iterator->second == rank) {
|
if (iterator->second == rank) {
|
||||||
iterator = ring_.erase(iterator);
|
(void)ring_.erase(iterator++);
|
||||||
} else {
|
} else {
|
||||||
++iterator;
|
++iterator;
|
||||||
}
|
}
|
||||||
|
|
|
@ -329,6 +329,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
|
||||||
|
|
||||||
// Broadcast to all follower servers.
|
// Broadcast to all follower servers.
|
||||||
for (uint32_t i = 1; i < server_num_; i++) {
|
for (uint32_t i = 1; i < server_num_; i++) {
|
||||||
|
MS_LOG(INFO) << "Start sending first count event message to server " << i;
|
||||||
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
||||||
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
|
MS_LOG(ERROR) << "Activating first count event to server " << i << " failed.";
|
||||||
if (reason != nullptr) {
|
if (reason != nullptr) {
|
||||||
|
@ -343,7 +344,9 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Leader server directly calls the callback.
|
// Leader server directly calls the callback.
|
||||||
|
MS_LOG(INFO) << "Leader server call first count handler for " << name << "...";
|
||||||
counter_handlers_[name].first_count_handler(nullptr);
|
counter_handlers_[name].first_count_handler(nullptr);
|
||||||
|
MS_LOG(INFO) << "First count handler for " << name << " is successfully called.";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -355,6 +358,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
|
||||||
|
|
||||||
// Broadcast to all follower servers.
|
// Broadcast to all follower servers.
|
||||||
for (uint32_t i = 1; i < server_num_; i++) {
|
for (uint32_t i = 1; i < server_num_; i++) {
|
||||||
|
MS_LOG(INFO) << "Start sending last count event message to server " << i;
|
||||||
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) {
|
||||||
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
|
MS_LOG(ERROR) << "Activating last count event to server " << i << " failed.";
|
||||||
if (reason != nullptr) {
|
if (reason != nullptr) {
|
||||||
|
@ -369,7 +373,9 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Leader server directly calls the callback.
|
// Leader server directly calls the callback.
|
||||||
|
MS_LOG(INFO) << "Leader server call last count handler for " << name << "...";
|
||||||
counter_handlers_[name].last_count_handler(nullptr);
|
counter_handlers_[name].last_count_handler(nullptr);
|
||||||
|
MS_LOG(INFO) << "Last count handler for " << name << " is successfully called.";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "utils/hash_map.h"
|
|
||||||
#include "proto/ps.pb.h"
|
#include "proto/ps.pb.h"
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
#include "ps/core/server_node.h"
|
#include "ps/core/server_node.h"
|
||||||
|
@ -118,14 +117,14 @@ class DistributedCountService {
|
||||||
|
|
||||||
// Key: name, e.g, startFLJob, updateModel, push.
|
// Key: name, e.g, startFLJob, updateModel, push.
|
||||||
// Value: a set of ids without repetition, because each worker may report multiple times.
|
// Value: a set of ids without repetition, because each worker may report multiple times.
|
||||||
mindspore::HashMap<std::string, std::set<std::string>> global_current_count_;
|
std::unordered_map<std::string, std::set<std::string>> global_current_count_;
|
||||||
|
|
||||||
// Key: name, e.g, StartFLJobCount.
|
// Key: name, e.g, StartFLJobCount.
|
||||||
// Value: global threshold count in the server cluster dimension for this name.
|
// Value: global threshold count in the server cluster dimension for this name.
|
||||||
mindspore::HashMap<std::string, size_t> global_threshold_count_;
|
std::unordered_map<std::string, size_t> global_threshold_count_;
|
||||||
|
|
||||||
// First/last count event callbacks of the name.
|
// First/last count event callbacks of the name.
|
||||||
mindspore::HashMap<std::string, CounterHandlers> counter_handlers_;
|
std::unordered_map<std::string, CounterHandlers> counter_handlers_;
|
||||||
|
|
||||||
// Because the count is increased/queried concurrently, we must ensure the operations are threadsafe.
|
// Because the count is increased/queried concurrently, we must ensure the operations are threadsafe.
|
||||||
std::unordered_map<std::string, std::mutex> mutex_;
|
std::unordered_map<std::string, std::mutex> mutex_;
|
||||||
|
|
|
@ -125,10 +125,6 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) {
|
||||||
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
MS_LOG(ERROR) << "The consistent hash ring is not initialized yet.";
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
if (metadata_.count(name) == 0) {
|
|
||||||
MS_LOG(ERROR) << "The metadata of " << name << " is not registered.";
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
uint32_t stored_rank = router_->Find(name);
|
uint32_t stored_rank = router_->Find(name);
|
||||||
MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
|
MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank;
|
||||||
if (local_rank_ == stored_rank) {
|
if (local_rank_ == stored_rank) {
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include "utils/hash_map.h"
|
|
||||||
#include "proto/ps.pb.h"
|
#include "proto/ps.pb.h"
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
#include "ps/core/server_node.h"
|
#include "ps/core/server_node.h"
|
||||||
|
@ -107,7 +106,7 @@ class DistributedMetadataStore {
|
||||||
// We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use.
|
// We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use.
|
||||||
// Key: data name.
|
// Key: data name.
|
||||||
// Value: ProtoBuffer Struct.
|
// Value: ProtoBuffer Struct.
|
||||||
mindspore::HashMap<std::string, PBMetadata> metadata_;
|
std::unordered_map<std::string, PBMetadata> metadata_;
|
||||||
|
|
||||||
// Because the metadata is read/written concurrently, we must ensure the operations are threadsafe.
|
// Because the metadata is read/written concurrently, we must ensure the operations are threadsafe.
|
||||||
std::unordered_map<std::string, std::mutex> mutex_;
|
std::unordered_map<std::string, std::mutex> mutex_;
|
||||||
|
|
|
@ -65,47 +65,6 @@ bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) {
|
||||||
|
|
||||||
bool Executor::initialized() const { return initialized_; }
|
bool Executor::initialized() const { return initialized_; }
|
||||||
|
|
||||||
bool Executor::HandlePush(const std::string ¶m_name, const UploadData &upload_data) {
|
|
||||||
MS_LOG(DEBUG) << "Do Push for parameter " << param_name;
|
|
||||||
if (param_aggrs_.count(param_name) == 0) {
|
|
||||||
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::mutex &mtx = parameter_mutex_[param_name];
|
|
||||||
std::unique_lock<std::mutex> lock(mtx);
|
|
||||||
auto ¶m_aggr = param_aggrs_[param_name];
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
|
|
||||||
// Push operation needs to wait until the pulling process is done.
|
|
||||||
while (!param_aggr->IsPullingDone()) {
|
|
||||||
lock.unlock();
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
|
|
||||||
lock.lock();
|
|
||||||
}
|
|
||||||
|
|
||||||
// 1.Update data with the uploaded data of the worker.
|
|
||||||
if (!param_aggr->UpdateData(upload_data)) {
|
|
||||||
MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// 2.Launch aggregation for this trainable parameter.
|
|
||||||
if (!param_aggr->LaunchAggregators()) {
|
|
||||||
MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (param_aggr->IsAggregationDone()) {
|
|
||||||
// 3.After the aggregation is done, optimize the trainable parameter.
|
|
||||||
if (!param_aggr->LaunchOptimizers()) {
|
|
||||||
MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
// 4.Reset pulling and aggregation status after optimizing is done.
|
|
||||||
param_aggr->ResetPullingStatus();
|
|
||||||
param_aggr->ResetAggregationStatus();
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Executor::HandleModelUpdate(const std::string ¶m_name, const UploadData &upload_data) {
|
bool Executor::HandleModelUpdate(const std::string ¶m_name, const UploadData &upload_data) {
|
||||||
MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
|
MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name;
|
||||||
if (param_aggrs_.count(param_name) == 0) {
|
if (param_aggrs_.count(param_name) == 0) {
|
||||||
|
@ -131,32 +90,6 @@ bool Executor::HandleModelUpdate(const std::string ¶m_name, const UploadData
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Executor::HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map) {
|
|
||||||
std::unique_lock<std::mutex> model_lock(model_mutex_);
|
|
||||||
for (const auto &trainable_param : feature_map) {
|
|
||||||
const std::string ¶m_name = trainable_param.first;
|
|
||||||
if (param_aggrs_.count(param_name) == 0) {
|
|
||||||
MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::mutex &mtx = parameter_mutex_[param_name];
|
|
||||||
std::unique_lock<std::mutex> lock(mtx);
|
|
||||||
auto ¶m_aggr = param_aggrs_[param_name];
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false);
|
|
||||||
const UploadData &upload_data = trainable_param.second;
|
|
||||||
if (!param_aggr->UpdateData(upload_data)) {
|
|
||||||
MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!param_aggr->LaunchAggregators()) {
|
|
||||||
MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
|
bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_map) {
|
||||||
for (const auto &trainable_param : feature_map) {
|
for (const auto &trainable_param : feature_map) {
|
||||||
const std::string ¶m_name = trainable_param.first;
|
const std::string ¶m_name = trainable_param.first;
|
||||||
|
@ -183,31 +116,6 @@ bool Executor::HandlePushWeight(const std::map<std::string, Address> &feature_ma
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handles a blocking Pull request for the trainable parameter `param_name`.
//
// Waits until optimizing for this parameter is done, then pulls its value from the aggregator.
// When the aggregator reports that pulling is done after this call, the optimizing status is reset.
// Returns nullptr if the parameter is not registered; otherwise returns the address produced by Pull().
AddressPtr Executor::HandlePull(const std::string &param_name) {
  MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name;
  if (param_aggrs_.count(param_name) == 0) {
    MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server.";
    return nullptr;
  }

  std::mutex &mtx = parameter_mutex_[param_name];
  std::unique_lock<std::mutex> lock(mtx);
  auto &param_aggr = param_aggrs_[param_name];
  // Presumably logs and returns nullptr when the aggregator is null -- confirm against the macro definition.
  MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, nullptr);
  // Pulling must wait until the optimizing process is done.
  // The lock is released during each sleep and re-acquired before the status is checked again.
  while (!param_aggr->IsOptimizingDone()) {
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime));
    lock.lock();
  }
  AddressPtr addr = param_aggr->Pull();
  // If this Pull is the last one, reset pulling and optimizing status.
  if (param_aggr->IsPullingDone()) {
    param_aggr->ResetOptimizingStatus();
  }
  return addr;
}
|
|
||||||
|
|
||||||
std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> ¶m_names) {
|
std::map<std::string, AddressPtr> Executor::HandlePullWeight(const std::vector<std::string> ¶m_names) {
|
||||||
std::map<std::string, AddressPtr> weights;
|
std::map<std::string, AddressPtr> weights;
|
||||||
for (const auto ¶m_name : param_names) {
|
for (const auto ¶m_name : param_names) {
|
||||||
|
@ -298,7 +206,7 @@ bool Executor::unmasked() const {
|
||||||
if (encrypt_type == ps::kPWEncryptType) {
|
if (encrypt_type == ps::kPWEncryptType) {
|
||||||
return unmasked_.load();
|
return unmasked_.load();
|
||||||
} else {
|
} else {
|
||||||
// If the algorithm of pairwise encrypt is not enabled, consider_ unmasked flag as true.
|
// If the algorithm of mind armour is not enabled, consider unmasked_ flag as true.
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -340,7 +248,7 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) {
|
||||||
param_aggrs_[param_name] = param_aggr;
|
param_aggrs_[param_name] = param_aggr;
|
||||||
parameter_mutex_[param_name];
|
parameter_mutex_[param_name];
|
||||||
if (!param_aggr->Init(cnode, aggregation_count_)) {
|
if (!param_aggr->Init(cnode, aggregation_count_)) {
|
||||||
MS_LOG(EXCEPTION) << "Initializing parameter aggregator for " << param_name << " failed.";
|
MS_LOG(EXCEPTION) << "Initializing parameter aggregator for param_name " << param_name << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
|
MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success.";
|
||||||
|
|
|
@ -24,11 +24,11 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <condition_variable>
|
#include <condition_variable>
|
||||||
#include "fl/server/common.h"
|
|
||||||
#include "fl/server/parameter_aggregator.h"
|
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
#include "fl/armour/cipher/cipher_unmask.h"
|
#include "fl/armour/cipher/cipher_unmask.h"
|
||||||
#endif
|
#endif
|
||||||
|
#include "fl/server/common.h"
|
||||||
|
#include "fl/server/parameter_aggregator.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
|
@ -54,28 +54,13 @@ class Executor {
|
||||||
// After hyper-parameters are updated, some parameter aggregators should be reinitialized.
|
// After hyper-parameters are updated, some parameter aggregators should be reinitialized.
|
||||||
bool ReInitForUpdatingHyperParams(size_t aggr_threshold);
|
bool ReInitForUpdatingHyperParams(size_t aggr_threshold);
|
||||||
|
|
||||||
// Called in parameter server training mode to do Push operation.
|
|
||||||
// For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered
|
|
||||||
// as completed.
|
|
||||||
bool HandlePush(const std::string ¶m_name, const UploadData &upload_data);
|
|
||||||
|
|
||||||
// Called in parameter server training mode to do Pull operation.
|
|
||||||
// Returns the value of parameter param_name.
|
|
||||||
// HandlePull method must be called the same times as HandlePush is called before it's considered as
|
|
||||||
// completed.
|
|
||||||
AddressPtr HandlePull(const std::string ¶m_name);
|
|
||||||
|
|
||||||
// Called in federated learning training mode. Update value for parameter param_name.
|
// Called in federated learning training mode. Update value for parameter param_name.
|
||||||
bool HandleModelUpdate(const std::string ¶m_name, const UploadData &upload_data);
|
bool HandleModelUpdate(const std::string ¶m_name, const UploadData &upload_data);
|
||||||
|
|
||||||
// Called in asynchronous federated learning training mode. Update current model with the new feature map
|
// Forcibly overwrite specific weights in overwriteWeights message.
|
||||||
// asynchronously.
|
|
||||||
bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map);
|
|
||||||
|
|
||||||
// Overwrite the weights in server using pushed feature map.
|
|
||||||
bool HandlePushWeight(const std::map<std::string, Address> &feature_map);
|
bool HandlePushWeight(const std::map<std::string, Address> &feature_map);
|
||||||
|
|
||||||
// Returns multiple trainable parameters passed by weight_names.
|
// Returns value for multiple trainable parameters passed by weight_names.
|
||||||
std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> ¶m_names);
|
std::map<std::string, AddressPtr> HandlePullWeight(const std::vector<std::string> ¶m_names);
|
||||||
|
|
||||||
// Reset the aggregation status for all aggregation kernels in the server.
|
// Reset the aggregation status for all aggregation kernels in the server.
|
||||||
|
@ -135,7 +120,7 @@ class Executor {
|
||||||
armour::CipherUnmask cipher_unmask_;
|
armour::CipherUnmask cipher_unmask_;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// The flag represents the unmasking status.
|
// The flag refers to the unmasking status
|
||||||
std::atomic<bool> unmasked_;
|
std::atomic<bool> unmasked_;
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -16,8 +16,8 @@
|
||||||
|
|
||||||
#include "fl/server/iteration.h"
|
#include "fl/server/iteration.h"
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include "fl/server/model_store.h"
|
#include "fl/server/model_store.h"
|
||||||
#include "fl/server/server.h"
|
#include "fl/server/server.h"
|
||||||
|
@ -38,6 +38,7 @@ Iteration::~Iteration() {
|
||||||
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
|
void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator) {
|
||||||
MS_EXCEPTION_IF_NULL(communicator);
|
MS_EXCEPTION_IF_NULL(communicator);
|
||||||
communicator_ = communicator;
|
communicator_ = communicator;
|
||||||
|
MS_EXCEPTION_IF_NULL(communicator_);
|
||||||
communicator_->RegisterMsgCallBack("syncIteration",
|
communicator_->RegisterMsgCallBack("syncIteration",
|
||||||
std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
|
std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1));
|
||||||
communicator_->RegisterMsgCallBack(
|
communicator_->RegisterMsgCallBack(
|
||||||
|
@ -54,10 +55,10 @@ void Iteration::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommu
|
||||||
void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) {
|
void Iteration::RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node) {
|
||||||
MS_EXCEPTION_IF_NULL(server_node);
|
MS_EXCEPTION_IF_NULL(server_node);
|
||||||
server_node_ = server_node;
|
server_node_ = server_node;
|
||||||
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
|
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::UserDefineEvent::kIterationRunning),
|
||||||
std::bind(&Iteration::HandleIterationRunningEvent, this));
|
std::bind(&Iteration::ProcessIterationRunningEvent, this));
|
||||||
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
|
server_node->RegisterCustomEventCallback(static_cast<uint32_t>(ps::UserDefineEvent::kIterationCompleted),
|
||||||
std::bind(&Iteration::HandleIterationCompletedEvent, this));
|
std::bind(&Iteration::ProcessIterationEndEvent, this));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::AddRound(const std::shared_ptr<Round> &round) {
|
void Iteration::AddRound(const std::shared_ptr<Round> &round) {
|
||||||
|
@ -97,6 +98,7 @@ void Iteration::InitRounds(const std::vector<std::shared_ptr<ps::core::Communica
|
||||||
if (!move_to_next_thread_running_.load()) {
|
if (!move_to_next_thread_running_.load()) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
lock.unlock();
|
||||||
MoveToNextIteration(is_last_iteration_valid_, move_to_next_reason_);
|
MoveToNextIteration(is_last_iteration_valid_, move_to_next_reason_);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -113,6 +115,7 @@ void Iteration::NotifyNext(bool is_last_iter_valid, const std::string &reason) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
|
void Iteration::MoveToNextIteration(bool is_last_iter_valid, const std::string &reason) {
|
||||||
|
iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
MS_LOG(INFO) << "Notify cluster starts to proceed to next iteration. Iteration is " << iteration_num_
|
MS_LOG(INFO) << "Notify cluster starts to proceed to next iteration. Iteration is " << iteration_num_
|
||||||
<< " validation is " << is_last_iter_valid << ". Reason: " << reason;
|
<< " validation is " << is_last_iter_valid << ". Reason: " << reason;
|
||||||
if (IsMoveToNextIterRequestReentrant(iteration_num_)) {
|
if (IsMoveToNextIterRequestReentrant(iteration_num_)) {
|
||||||
|
@ -147,7 +150,13 @@ void Iteration::SetIterationRunning() {
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
|
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
|
||||||
if (server_node_->rank_id() == kLeaderServerRank) {
|
if (server_node_->rank_id() == kLeaderServerRank) {
|
||||||
// This event helps worker/server to be consistent in iteration state.
|
// This event helps worker/server to be consistent in iteration state.
|
||||||
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning));
|
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::UserDefineEvent::kIterationRunning));
|
||||||
|
if (server_recovery_ != nullptr) {
|
||||||
|
// Save data to the persistent storage in case the recovery happens at the beginning.
|
||||||
|
if (!server_recovery_->Save(iteration_num_)) {
|
||||||
|
MS_LOG(WARNING) << "Save recovery data failed.";
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
|
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
|
||||||
|
@ -155,12 +164,12 @@ void Iteration::SetIterationRunning() {
|
||||||
start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
|
start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Iteration::SetIterationCompleted() {
|
void Iteration::SetIterationEnd() {
|
||||||
MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes.";
|
MS_LOG(INFO) << "Iteration " << iteration_num_ << " ends.";
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
|
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
|
||||||
if (server_node_->rank_id() == kLeaderServerRank) {
|
if (server_node_->rank_id() == kLeaderServerRank) {
|
||||||
// This event helps worker/server to be consistent in iteration state.
|
// This event helps worker/server to be consistent in iteration state.
|
||||||
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted));
|
server_node_->BroadcastEvent(static_cast<uint32_t>(ps::UserDefineEvent::kIterationCompleted));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
|
std::unique_lock<std::mutex> lock(iteration_state_mtx_);
|
||||||
|
@ -274,7 +283,7 @@ bool Iteration::DisableServerInstance(std::string *result) {
|
||||||
instance_state_ = InstanceState::kDisable;
|
instance_state_ = InstanceState::kDisable;
|
||||||
if (!ForciblyMoveToNextIteration()) {
|
if (!ForciblyMoveToNextIteration()) {
|
||||||
*result = "Disabling instance failed. Can't drop current iteration and move to the next.";
|
*result = "Disabling instance failed. Can't drop current iteration and move to the next.";
|
||||||
MS_LOG(ERROR) << *result;
|
MS_LOG(ERROR) << result;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
*result = "Disabling FL-Server succeeded.";
|
*result = "Disabling FL-Server succeeded.";
|
||||||
|
@ -295,6 +304,11 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (iteration_num_ == 1) {
|
||||||
|
MS_LOG(INFO) << "This is just the first iteration.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
// Start new server instance.
|
// Start new server instance.
|
||||||
is_instance_being_updated_ = true;
|
is_instance_being_updated_ = true;
|
||||||
|
|
||||||
|
@ -312,7 +326,7 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string
|
||||||
ModelStore::GetInstance().Reset();
|
ModelStore::GetInstance().Reset();
|
||||||
if (metrics_ != nullptr) {
|
if (metrics_ != nullptr) {
|
||||||
if (!metrics_->Clear()) {
|
if (!metrics_->Clear()) {
|
||||||
MS_LOG(WARNING) << "Clear metrics fil failed.";
|
MS_LOG(WARNING) << "Clear metrics file failed.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -340,6 +354,16 @@ void Iteration::WaitAllRoundsFinish() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stores the recovery handler used by this Iteration.
// The handler is validated with MS_EXCEPTION_IF_NULL before the member is overwritten, so a null handler never
// replaces an existing one.
void Iteration::set_recovery_handler(const std::shared_ptr<ServerRecovery> &server_recovery) {
  MS_EXCEPTION_IF_NULL(server_recovery);
  server_recovery_ = server_recovery;
}
|
||||||
|
|
||||||
|
// Called after recovery: asks the iteration to move on, marking the last iteration as invalid.
// The uint64_t argument is unnamed and unused here. Always returns true.
bool Iteration::SyncAfterRecovery(uint64_t) {
  const bool is_last_iter_valid = false;
  NotifyNext(is_last_iter_valid, "Move to next iteration after recovery.");
  return true;
}
|
||||||
|
|
||||||
bool Iteration::SyncIteration(uint32_t rank) {
|
bool Iteration::SyncIteration(uint32_t rank) {
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false);
|
||||||
SyncIterationRequest sync_iter_req;
|
SyncIterationRequest sync_iter_req;
|
||||||
|
@ -348,7 +372,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
|
||||||
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
|
std::shared_ptr<std::vector<unsigned char>> sync_iter_rsp_msg = nullptr;
|
||||||
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration,
|
if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration,
|
||||||
&sync_iter_rsp_msg)) {
|
&sync_iter_rsp_msg)) {
|
||||||
MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed.";
|
MS_LOG(ERROR) << "Sending sync iter message to leader server failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -356,8 +380,7 @@ bool Iteration::SyncIteration(uint32_t rank) {
|
||||||
SyncIterationResponse sync_iter_rsp;
|
SyncIterationResponse sync_iter_rsp;
|
||||||
(void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
|
(void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size()));
|
||||||
iteration_num_ = sync_iter_rsp.iteration();
|
iteration_num_ = sync_iter_rsp.iteration();
|
||||||
MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is "
|
MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is " << iteration_num_;
|
||||||
<< sync_iter_rsp.iteration();
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -496,7 +519,9 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::
|
||||||
void Iteration::PrepareForNextIter() {
|
void Iteration::PrepareForNextIter() {
|
||||||
MS_LOG(INFO) << "Prepare for next iteration. Switch the server to safemode.";
|
MS_LOG(INFO) << "Prepare for next iteration. Switch the server to safemode.";
|
||||||
Server::GetInstance().SwitchToSafeMode();
|
Server::GetInstance().SwitchToSafeMode();
|
||||||
|
MS_LOG(INFO) << "Start waiting for rounds to finish.";
|
||||||
WaitAllRoundsFinish();
|
WaitAllRoundsFinish();
|
||||||
|
MS_LOG(INFO) << "End waiting for rounds to finish.";
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
|
bool Iteration::BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason) {
|
||||||
|
@ -627,12 +652,22 @@ void Iteration::EndLastIter() {
|
||||||
pinned_iter_num_ = 0;
|
pinned_iter_num_ = 0;
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
|
|
||||||
SetIterationCompleted();
|
SetIterationEnd();
|
||||||
if (!SummarizeIteration()) {
|
if (!SummarizeIteration()) {
|
||||||
MS_LOG(WARNING) << "Summarizing iteration data failed.";
|
MS_LOG(WARNING) << "Summarizing iteration data failed.";
|
||||||
}
|
}
|
||||||
iteration_num_++;
|
iteration_num_++;
|
||||||
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_);
|
||||||
|
|
||||||
|
MS_ERROR_IF_NULL_WO_RET_VAL(server_node_);
|
||||||
|
if (server_node_->rank_id() == kLeaderServerRank) {
|
||||||
|
// Save current iteration number for recovery.
|
||||||
|
MS_ERROR_IF_NULL_WO_RET_VAL(server_recovery_);
|
||||||
|
if (!server_recovery_->Save(iteration_num_)) {
|
||||||
|
MS_LOG(WARNING) << "Can't save current iteration number into persistent storage.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Server::GetInstance().CancelSafeMode();
|
Server::GetInstance().CancelSafeMode();
|
||||||
iteration_state_cv_.notify_all();
|
iteration_state_cv_.notify_all();
|
||||||
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
|
MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n";
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "fl/server/round.h"
|
#include "fl/server/round.h"
|
||||||
#include "fl/server/local_meta_store.h"
|
#include "fl/server/local_meta_store.h"
|
||||||
#include "fl/server/iteration_metrics.h"
|
#include "fl/server/iteration_metrics.h"
|
||||||
|
#include "fl/server/server_recovery.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
|
@ -52,7 +53,7 @@ class Iteration {
|
||||||
// Register callbacks for other servers to synchronize iteration information from leader server.
|
// Register callbacks for other servers to synchronize iteration information from leader server.
|
||||||
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
|
void RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunicator> &communicator);
|
||||||
|
|
||||||
// Register event callbacks for iteration state synchronization.
|
// Register event callback for iteration state synchronization.
|
||||||
void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node);
|
void RegisterEventCallback(const std::shared_ptr<ps::core::ServerNode> &server_node);
|
||||||
|
|
||||||
// Add a round for the iteration. This method will be called multiple times for each round.
|
// Add a round for the iteration. This method will be called multiple times for each round.
|
||||||
|
@ -70,14 +71,14 @@ class Iteration {
|
||||||
|
|
||||||
// This method will control servers to proceed to next iteration.
|
// This method will control servers to proceed to next iteration.
|
||||||
// There's communication between leader and follower servers in this method.
|
// There's communication between leader and follower servers in this method.
|
||||||
// The server moves to next iteration only after the last round finishes or the time expires.
|
// The server moves to the next iteration only after the last round finishes or the timer expires.
|
||||||
void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
|
void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
|
||||||
|
|
||||||
// Set current iteration state to running and trigger events about kIterationRunning.
|
// Set current iteration state to running and trigger the event.
|
||||||
void SetIterationRunning();
|
void SetIterationRunning();
|
||||||
|
|
||||||
// Set current iteration state to completed and trigger the event about kIterationCompleted.
|
// Set current iteration state to end and trigger the event.
|
||||||
void SetIterationCompleted();
|
void SetIterationEnd();
|
||||||
|
|
||||||
// The barrier function for elastic scaling. The scaling out/in operation should be done only after this iteration is
|
// The barrier function for elastic scaling. The scaling out/in operation should be done only after this iteration is
|
||||||
// completed.
|
// completed.
|
||||||
|
@ -118,6 +119,12 @@ class Iteration {
|
||||||
// Need to wait all the rounds to finish before proceed to next iteration.
|
// Need to wait all the rounds to finish before proceed to next iteration.
|
||||||
void WaitAllRoundsFinish() const;
|
void WaitAllRoundsFinish() const;
|
||||||
|
|
||||||
|
// Set server's recovery handler.
|
||||||
|
void set_recovery_handler(const std::shared_ptr<ServerRecovery> &server_recovery);
|
||||||
|
|
||||||
|
// Synchronize server iteration after another server's recovery is completed.
|
||||||
|
bool SyncAfterRecovery(uint64_t iteration_num);
|
||||||
|
|
||||||
// The round kernels whose Launch method has not returned yet.
|
// The round kernels whose Launch method has not returned yet.
|
||||||
std::atomic_uint32_t running_round_num_;
|
std::atomic_uint32_t running_round_num_;
|
||||||
|
|
||||||
|
@ -150,10 +157,10 @@ class Iteration {
|
||||||
Iteration &operator=(const Iteration &) = delete;
|
Iteration &operator=(const Iteration &) = delete;
|
||||||
|
|
||||||
// The server does not need to handle the iteration events for now.
|
// The server does not need to handle the iteration events for now.
|
||||||
void HandleIterationRunningEvent() {}
|
void ProcessIterationRunningEvent() {}
|
||||||
void HandleIterationCompletedEvent() {}
|
void ProcessIterationEndEvent() {}
|
||||||
|
|
||||||
// Synchronize iteration form the leader server(Rank 0).
|
// Synchronize iteration from the leader server(Rank 0).
|
||||||
bool SyncIteration(uint32_t rank);
|
bool SyncIteration(uint32_t rank);
|
||||||
void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
void HandleSyncIterationRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
|
@ -165,13 +172,13 @@ class Iteration {
|
||||||
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
|
bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason);
|
||||||
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// Step 2: leader server broadcast to all follower servers to prepare for next iteration and switch to safemode.
|
// Step 2: leader server broadcasts to all follower servers to prepare for next iteration and switch to safemode..
|
||||||
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
|
bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason);
|
||||||
void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
void HandlePrepareForNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
// The server prepare for the next iteration. This method will switch the server to safemode.
|
// The server prepare for the next iteration. This method will switch the server to safemode.
|
||||||
void PrepareForNextIter();
|
void PrepareForNextIter();
|
||||||
|
|
||||||
// Step 3: leader server broadcast to all follower servers to move to next iteration.
|
// Step 3: leader server broadcasts to all follower servers to move to next iteration.
|
||||||
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
|
bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason);
|
||||||
void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
void HandleMoveToNextIterRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
// Move to next iteration. Store last iterations model and reset all the rounds.
|
// Move to next iteration. Store last iterations model and reset all the rounds.
|
||||||
|
@ -201,6 +208,9 @@ class Iteration {
|
||||||
// All the rounds in the server.
|
// All the rounds in the server.
|
||||||
std::vector<std::shared_ptr<Round>> rounds_;
|
std::vector<std::shared_ptr<Round>> rounds_;
|
||||||
|
|
||||||
|
// The recovery object for server.
|
||||||
|
std::shared_ptr<ServerRecovery> server_recovery_;
|
||||||
|
|
||||||
// The iteration is either running or completed at any time.
|
// The iteration is either running or completed at any time.
|
||||||
std::mutex iteration_state_mtx_;
|
std::mutex iteration_state_mtx_;
|
||||||
std::condition_variable iteration_state_cv_;
|
std::condition_variable iteration_state_cv_;
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "fl/server/iteration_metrics.h"
|
#include "fl/server/iteration_metrics.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include "utils/file_utils.h"
|
||||||
#include "debug/common.h"
|
#include "debug/common.h"
|
||||||
#include "ps/constants.h"
|
#include "ps/constants.h"
|
||||||
|
|
||||||
|
@ -68,7 +69,7 @@ bool IterationMetrics::Initialize() {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IterationMetrics::Summarize() {
|
bool IterationMetrics::Summarize() {
|
||||||
metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out);
|
metrics_file_.open(metrics_file_path_, std::ios::out | std::ios::app);
|
||||||
if (!metrics_file_.is_open()) {
|
if (!metrics_file_.is_open()) {
|
||||||
MS_LOG(ERROR) << "The metrics file is not opened.";
|
MS_LOG(ERROR) << "The metrics file is not opened.";
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -39,20 +39,11 @@ constexpr auto kRejectedClientNum = "rejectedClientNum";
|
||||||
constexpr auto kMetricsAuc = "metricsAuc";
|
constexpr auto kMetricsAuc = "metricsAuc";
|
||||||
constexpr auto kMetricsLoss = "metricsLoss";
|
constexpr auto kMetricsLoss = "metricsLoss";
|
||||||
constexpr auto kIterExecutionTime = "iterationExecutionTime";
|
constexpr auto kIterExecutionTime = "iterationExecutionTime";
|
||||||
|
constexpr auto kMetrics = "metrics";
|
||||||
|
|
||||||
const std::map<InstanceState, std::string> kInstanceStateName = {
|
const std::map<InstanceState, std::string> kInstanceStateName = {
|
||||||
{InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}};
|
{InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}};
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) {
|
|
||||||
if (!json.contains(key)) {
|
|
||||||
MS_LOG(EXCEPTION) << "The key " << key << "does not exist in json " << json.dump();
|
|
||||||
}
|
|
||||||
return json[key].get<T>();
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr auto kMetrics = "metrics";
|
|
||||||
|
|
||||||
class IterationMetrics {
|
class IterationMetrics {
|
||||||
public:
|
public:
|
||||||
explicit IterationMetrics(const std::string &config_file)
|
explicit IterationMetrics(const std::string &config_file)
|
||||||
|
|
|
@ -19,6 +19,13 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
|
IterationTimer::~IterationTimer() {
|
||||||
|
running_ = false;
|
||||||
|
if (monitor_thread_.joinable()) {
|
||||||
|
monitor_thread_.join();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void IterationTimer::Start(const std::chrono::milliseconds &duration) {
|
void IterationTimer::Start(const std::chrono::milliseconds &duration) {
|
||||||
if (running_.load()) {
|
if (running_.load()) {
|
||||||
MS_LOG(WARNING) << "The timer already started.";
|
MS_LOG(WARNING) << "The timer already started.";
|
||||||
|
@ -50,7 +57,7 @@ void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) const {
|
bool IterationTimer::IsTimeOut(const std::chrono::milliseconds &timestamp) {
|
||||||
return timestamp > end_time_ ? true : false;
|
return timestamp > end_time_ ? true : false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace server {
|
||||||
class IterationTimer {
|
class IterationTimer {
|
||||||
public:
|
public:
|
||||||
IterationTimer() : running_(false), end_time_(0) {}
|
IterationTimer() : running_(false), end_time_(0) {}
|
||||||
~IterationTimer() = default;
|
~IterationTimer();
|
||||||
|
|
||||||
// Start timing. The timer will stop after parameter 'duration' milliseconds.
|
// Start timing. The timer will stop after parameter 'duration' milliseconds.
|
||||||
void Start(const std::chrono::milliseconds &duration);
|
void Start(const std::chrono::milliseconds &duration);
|
||||||
|
@ -42,7 +42,7 @@ class IterationTimer {
|
||||||
void SetTimeOutCallBack(const TimeOutCb &timeout_cb);
|
void SetTimeOutCallBack(const TimeOutCb &timeout_cb);
|
||||||
|
|
||||||
// Judge whether current timestamp is out of time window's range since the Start function is called.
|
// Judge whether current timestamp is out of time window's range since the Start function is called.
|
||||||
bool IsTimeOut(const std::chrono::milliseconds &timestamp) const;
|
bool IsTimeOut(const std::chrono::milliseconds &timestamp);
|
||||||
|
|
||||||
// Judge whether the timer is keeping timing.
|
// Judge whether the timer is keeping timing.
|
||||||
bool IsRunning() const;
|
bool IsRunning() const;
|
||||||
|
|
|
@ -36,59 +36,10 @@ class DenseGradAccumKernel : public AggregationKernel {
|
||||||
DenseGradAccumKernel() = default;
|
DenseGradAccumKernel() = default;
|
||||||
~DenseGradAccumKernel() override = default;
|
~DenseGradAccumKernel() override = default;
|
||||||
|
|
||||||
void InitKernel(const CNodePtr &kernel_node) override {
|
void InitKernel(const CNodePtr &) override { return; }
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
||||||
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
|
|
||||||
if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
|
|
||||||
kNameToIdxMap.at(cnode_name).at("inputs").count("grad") == 0) {
|
|
||||||
MS_LOG(EXCEPTION) << "Can't find index info of grad for kernel " << cnode_name;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
size_t cnode_grad_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("grad");
|
|
||||||
std::vector<size_t> grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_grad_idx);
|
|
||||||
size_t grad_size = std::accumulate(grad_shape.begin(), grad_shape.end(), sizeof(T), std::multiplies<size_t>());
|
|
||||||
input_size_list_.push_back(grad_size);
|
|
||||||
size_t new_grad_size = grad_size;
|
|
||||||
input_size_list_.push_back(new_grad_size);
|
|
||||||
GenerateReuseKernelNodeInfo();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &) override {
|
const std::vector<AddressPtr> &) override {
|
||||||
if (inputs.size() != kDenseGradAccumKernelInputsNum) {
|
|
||||||
MS_LOG(ERROR) << "The inputs number of DenseGradAccumKernel should be 2, but got " << inputs.size();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0], false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1], false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false);
|
|
||||||
|
|
||||||
if (accum_count_ == 0) {
|
|
||||||
int ret = memset_s(inputs[0]->addr, inputs[0]->size, 0x00, inputs[0]->size);
|
|
||||||
if (ret != 0) {
|
|
||||||
MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
T *grad_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
|
||||||
T *new_grad_addr = reinterpret_cast<T *>(inputs[1]->addr);
|
|
||||||
for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) {
|
|
||||||
grad_addr[i] += new_grad_addr[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
accum_count_++;
|
|
||||||
if (accum_count_ > done_count_) {
|
|
||||||
MS_LOG(ERROR) << "accum_count_ should not be greater than done_count_ " << done_count_;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (accum_count_ == done_count_) {
|
|
||||||
for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) {
|
|
||||||
grad_addr[i] /= done_count_;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -108,8 +108,8 @@ class FedAvgKernel : public AggregationKernel {
|
||||||
done_ = true;
|
done_ = true;
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
GenerateReuseKernelNodeInfo();
|
|
||||||
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_});
|
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_});
|
||||||
|
GenerateReuseKernelNodeInfo();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,14 +119,9 @@ class FedAvgKernel : public AggregationKernel {
|
||||||
MS_LOG(ERROR) << "The inputs number of FedAvgKernel should be 4, but got " << inputs.size();
|
MS_LOG(ERROR) << "The inputs number of FedAvgKernel should be 4, but got " << inputs.size();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0], false);
|
for (size_t i = 0; i < inputs.size(); i++) {
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1], false);
|
MS_ERROR_IF_NULL_W_RET_VAL(inputs[i]->addr, false);
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[2], false);
|
}
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[3], false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[2]->addr, false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(inputs[3]->addr, false);
|
|
||||||
|
|
||||||
std::unique_lock<std::mutex> lock(weight_mutex_);
|
std::unique_lock<std::mutex> lock(weight_mutex_);
|
||||||
// The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication
|
// The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication
|
||||||
|
@ -181,7 +176,7 @@ class FedAvgKernel : public AggregationKernel {
|
||||||
bool ReInitForUpdatingHyperParams(size_t aggr_threshold) override {
|
bool ReInitForUpdatingHyperParams(size_t aggr_threshold) override {
|
||||||
done_count_ = aggr_threshold;
|
done_count_ = aggr_threshold;
|
||||||
if (!DistributedCountService::GetInstance().ReInitCounter(name_, done_count_)) {
|
if (!DistributedCountService::GetInstance().ReInitCounter(name_, done_count_)) {
|
||||||
MS_LOG(ERROR) << "Reinitializing counter for " << name_ << " failed.";
|
MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "utils/hash_map.h"
|
#include <unordered_map>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
#include "fl/server/kernel/params_info.h"
|
#include "fl/server/kernel/params_info.h"
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ class KernelFactory {
|
||||||
|
|
||||||
// Generally, a server kernel can correspond to several ParamsInfo which is registered by the method 'Register' in
|
// Generally, a server kernel can correspond to several ParamsInfo which is registered by the method 'Register' in
|
||||||
// server kernel's *.cc files.
|
// server kernel's *.cc files.
|
||||||
mindspore::HashMap<std::string, std::vector<std::pair<ParamsInfo, C>>> name_to_creator_map_;
|
std::unordered_map<std::string, std::vector<std::pair<ParamsInfo, C>>> name_to_creator_map_;
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -21,49 +21,7 @@ namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
bool OptimizerKernelFactory::Matched(const ParamsInfo &params_info, const CNodePtr &kernel_node) {
|
bool OptimizerKernelFactory::Matched(const ParamsInfo &, const CNodePtr &) { return true; }
|
||||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
||||||
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
|
|
||||||
if (kNameToIdxMap.count(cnode_name) == 0) {
|
|
||||||
MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto input_name_to_idx = kNameToIdxMap.at(cnode_name).at("inputs");
|
|
||||||
size_t input_num = params_info.inputs_num();
|
|
||||||
for (size_t i = 0; i < input_num; i++) {
|
|
||||||
auto one_input_name_type = params_info.inputs_name_type(i);
|
|
||||||
std::string name = one_input_name_type.first;
|
|
||||||
if (input_name_to_idx.count(name) == 0) {
|
|
||||||
MS_LOG(EXCEPTION) << cnode_name << " does not have input named " << name;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
size_t input_idx = input_name_to_idx.at(name);
|
|
||||||
TypeId kernel_node_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_idx);
|
|
||||||
TypeId registered_input_type = one_input_name_type.second;
|
|
||||||
if (registered_input_type != kernel_node_input_type) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto output_name_to_idx = kNameToIdxMap.at(cnode_name).at("outputs");
|
|
||||||
size_t output_num = params_info.outputs_num();
|
|
||||||
for (size_t i = 0; i < output_num; i++) {
|
|
||||||
auto one_output_name_type = params_info.outputs_name_type(i);
|
|
||||||
std::string name = one_output_name_type.first;
|
|
||||||
if (output_name_to_idx.count(name) == 0) {
|
|
||||||
MS_LOG(EXCEPTION) << cnode_name << " does not have output named " << name;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
size_t output_idx = output_name_to_idx.at(name);
|
|
||||||
TypeId kernel_node_output_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_idx);
|
|
||||||
TypeId registered_output_type = one_output_name_type.second;
|
|
||||||
if (registered_output_type != kernel_node_output_type) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace fl
|
} // namespace fl
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
#include "schema/cipher_generated.h"
|
#include "schema/cipher_generated.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -29,16 +30,50 @@ void ClientListKernel::InitKernel(size_t) {
|
||||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
}
|
}
|
||||||
|
|
||||||
executor_ = &Executor::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
|
||||||
if (!executor_->initialized()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
cipher_init_ = &armour::CipherInit::GetInstance();
|
cipher_init_ = &armour::CipherInit::GetInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sigVerifyResult ClientListKernel::VerifySignature(const schema::GetClientList *get_clients_req) {
|
||||||
|
std::string fl_id = get_clients_req->fl_id()->str();
|
||||||
|
std::string timestamp = get_clients_req->timestamp()->str();
|
||||||
|
int iteration = get_clients_req->iteration();
|
||||||
|
std::string iter_str = std::to_string(iteration);
|
||||||
|
auto fbs_signature = get_clients_req->signature();
|
||||||
|
std::vector<unsigned char> signature;
|
||||||
|
if (fbs_signature == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "signature in get_clients_req is nullptr";
|
||||||
|
return sigVerifyResult::FAILED;
|
||||||
|
}
|
||||||
|
signature.assign(fbs_signature->begin(), fbs_signature->end());
|
||||||
|
std::map<std::string, std::string> key_attestations;
|
||||||
|
const fl::PBMetadata &key_attestations_pb_out =
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
|
||||||
|
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
|
||||||
|
auto iter = key_attestation_pb.key_attestations().begin();
|
||||||
|
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
|
||||||
|
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
|
||||||
|
}
|
||||||
|
if (key_attestations.find(fl_id) == key_attestations.end()) {
|
||||||
|
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
|
||||||
|
return sigVerifyResult::FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<unsigned char> src_data;
|
||||||
|
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
|
||||||
|
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
|
||||||
|
mindspore::ps::server::CertVerify certVerify;
|
||||||
|
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
|
||||||
|
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
|
||||||
|
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
|
||||||
|
return sigVerifyResult::FAILED;
|
||||||
|
}
|
||||||
|
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
|
||||||
|
return sigVerifyResult::TIMEOUT;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
|
||||||
|
return sigVerifyResult::PASSED;
|
||||||
|
}
|
||||||
|
|
||||||
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
|
bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
|
||||||
const std::shared_ptr<server::FBBuilder> &fbb) {
|
const std::shared_ptr<server::FBBuilder> &fbb) {
|
||||||
std::vector<string> client_list;
|
std::vector<string> client_list;
|
||||||
|
@ -48,7 +83,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
|
||||||
if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
|
if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) {
|
||||||
MS_LOG(ERROR) << "update_model_client_threshold is not set.";
|
MS_LOG(ERROR) << "update_model_client_threshold is not set.";
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_threshold is not set.",
|
BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_threshold is not set.",
|
||||||
empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
|
uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value<uint64_t>(kCtxUpdateModelThld);
|
||||||
|
@ -57,11 +92,11 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
|
||||||
for (size_t i = 0; i < IntToSize(client_list_pb.fl_id_size()); ++i) {
|
for (size_t i = 0; i < IntToSize(client_list_pb.fl_id_size()); ++i) {
|
||||||
client_list.push_back(client_list_pb.fl_id(SizeToInt(i)));
|
client_list.push_back(client_list_pb.fl_id(SizeToInt(i)));
|
||||||
}
|
}
|
||||||
if (static_cast<uint64_t>(client_list.size()) < update_model_client_needed) {
|
if (client_list.size() < update_model_client_needed) {
|
||||||
MS_LOG(INFO) << "The server is not ready. update_model_client_needed: " << update_model_client_needed;
|
MS_LOG(INFO) << "The server is not ready. update_model_client_needed: " << update_model_client_needed;
|
||||||
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
|
MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size();
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", empty_client_list,
|
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", empty_client_list,
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,7 +104,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
|
||||||
std::string reason = "fl_id: " + fl_id + " is not in the update_model_clients";
|
std::string reason = "fl_id: " + fl_id + " is not in the update_model_clients";
|
||||||
MS_LOG(INFO) << reason;
|
MS_LOG(INFO) << reason;
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, empty_client_list,
|
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, empty_client_list,
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,34 +114,27 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient
|
||||||
std::string reason = "update get update model clients failed";
|
std::string reason = "update get update model clients failed";
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, reason, empty_client_list,
|
BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, reason, empty_client_list,
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str())) {
|
if (!DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str())) {
|
||||||
std::string reason = "Counting for get user list request failed. Please retry later.";
|
std::string reason = "Counting for get user list request failed. Please retry later.";
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list,
|
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list,
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "send clients_list succeed!";
|
|
||||||
MS_LOG(INFO) << "UpdateModel client list: ";
|
|
||||||
for (size_t i = 0; i < client_list.size(); ++i) {
|
|
||||||
MS_LOG(INFO) << " fl_id : " << client_list[i];
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "update_model_client_needed: " << update_model_client_needed;
|
MS_LOG(INFO) << "update_model_client_needed: " << update_model_client_needed;
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
|
BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list,
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
MS_LOG(INFO) << "Launching ClientListKernel, Iteration number is " << iter_num;
|
||||||
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration;
|
|
||||||
clock_t start_time = clock();
|
|
||||||
|
|
||||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
std::string reason = "inputs or outputs size is invalid.";
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
|
@ -121,26 +149,66 @@ bool ClientListKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::vector<string> client_list;
|
std::vector<string> client_list;
|
||||||
|
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||||
|
if (!verifier.VerifyBuffer<schema::GetClientList>()) {
|
||||||
|
std::string reason = "The schema of GetClientList is invalid.";
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
|
const schema::GetClientList *get_clients_req = flatbuffers::GetRoot<schema::GetClientList>(req_data);
|
||||||
|
if (get_clients_req == nullptr) {
|
||||||
|
std::string reason = "Building flatbuffers schema failed for GetClientList.";
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// verify signature
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
sigVerifyResult verify_result = VerifySignature(get_clients_req);
|
||||||
|
if (verify_result == sigVerifyResult::FAILED) {
|
||||||
|
std::string reason = "verify signature failed.";
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||||
|
std::string reason = "verify signature timestamp failed.";
|
||||||
|
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, client_list,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (verify_result == sigVerifyResult::PASSED) {
|
||||||
|
MS_LOG(INFO) << "verify signature passed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
size_t iter_client = IntToSize(get_clients_req->iteration());
|
size_t iter_client = IntToSize(get_clients_req->iteration());
|
||||||
if (iter_num != iter_client) {
|
if (iter_num != iter_client) {
|
||||||
MS_LOG(ERROR) << "client list iteration number is invalid: server now iteration is " << iter_num
|
MS_LOG(ERROR) << "client list iteration number is invalid: server now iteration is " << iter_num
|
||||||
<< ". client request iteration is " << iter_client;
|
<< ". client request iteration is " << iter_client;
|
||||||
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
|
BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list,
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
MS_LOG(ERROR) << "Current amount for GetClientList is enough.";
|
MS_LOG(WARNING) << "Current amount for GetClientList is enough.";
|
||||||
}
|
}
|
||||||
|
|
||||||
(void)DealClient(iter_num, get_clients_req, fbb);
|
if (!DealClient(iter_num, get_clients_req, fbb)) {
|
||||||
|
MS_LOG(WARNING) << "Get Client List not ready.";
|
||||||
|
}
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
|
||||||
MS_LOG(INFO) << "client_list_kernel success time is : " << duration;
|
|
||||||
return true;
|
return true;
|
||||||
} // namespace fl
|
} // namespace fl
|
||||||
|
|
||||||
|
@ -153,28 +221,28 @@ bool ClientListKernel::Reset() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ClientListKernel::BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &client_list_resp_builder,
|
void ClientListKernel::BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &fbb,
|
||||||
const schema::ResponseCode retcode, const string &reason,
|
const schema::ResponseCode retcode, const string &reason,
|
||||||
std::vector<std::string> clients, const string &next_req_time,
|
std::vector<std::string> clients, const string &next_req_time,
|
||||||
const int iteration) {
|
const size_t iteration) {
|
||||||
auto rsp_reason = client_list_resp_builder->CreateString(reason);
|
auto rsp_reason = fbb->CreateString(reason);
|
||||||
auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time);
|
auto rsp_next_req_time = fbb->CreateString(next_req_time);
|
||||||
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
|
std::vector<flatbuffers::Offset<flatbuffers::String>> clients_vector;
|
||||||
for (auto client : clients) {
|
for (auto client : clients) {
|
||||||
auto client_fb = client_list_resp_builder->CreateString(client);
|
auto client_fb = fbb->CreateString(client);
|
||||||
clients_vector.push_back(client_fb);
|
clients_vector.push_back(client_fb);
|
||||||
MS_LOG(WARNING) << "update client list: ";
|
MS_LOG(WARNING) << "update client list: ";
|
||||||
MS_LOG(WARNING) << client;
|
MS_LOG(WARNING) << client;
|
||||||
}
|
}
|
||||||
auto clients_fb = client_list_resp_builder->CreateVector(clients_vector);
|
auto clients_fb = fbb->CreateVector(clients_vector);
|
||||||
schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get()));
|
schema::ReturnClientListBuilder rsp_builder(*(fbb.get()));
|
||||||
rsp_builder.add_retcode(retcode);
|
rsp_builder.add_retcode(SizeToInt(retcode));
|
||||||
rsp_builder.add_reason(rsp_reason);
|
rsp_builder.add_reason(rsp_reason);
|
||||||
rsp_builder.add_clients(clients_fb);
|
rsp_builder.add_clients(clients_fb);
|
||||||
rsp_builder.add_iteration(iteration);
|
rsp_builder.add_iteration(SizeToInt(iteration));
|
||||||
rsp_builder.add_next_req_time(rsp_next_req_time);
|
rsp_builder.add_next_req_time(rsp_next_req_time);
|
||||||
auto rsp_exchange_keys = rsp_builder.Finish();
|
auto rsp_exchange_keys = rsp_builder.Finish();
|
||||||
client_list_resp_builder->Finish(rsp_exchange_keys);
|
fbb->Finish(rsp_exchange_keys);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,9 @@ namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
// results of signature verification
|
||||||
|
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
|
||||||
|
|
||||||
class ClientListKernel : public RoundKernel {
|
class ClientListKernel : public RoundKernel {
|
||||||
public:
|
public:
|
||||||
ClientListKernel() = default;
|
ClientListKernel() = default;
|
||||||
|
@ -37,12 +40,13 @@ class ClientListKernel : public RoundKernel {
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) override;
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
bool Reset() override;
|
bool Reset() override;
|
||||||
void BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &client_list_resp_builder,
|
void BuildClientListRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||||
const schema::ResponseCode retcode, const string &reason, std::vector<std::string> clients,
|
const string &reason, std::vector<std::string> clients, const string &next_req_time,
|
||||||
const string &next_req_time, const int iteration);
|
const size_t iteration);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
armour::CipherInit *cipher_init_;
|
armour::CipherInit *cipher_init_;
|
||||||
|
sigVerifyResult VerifySignature(const schema::GetClientList *get_clients_req);
|
||||||
bool DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
|
bool DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req,
|
||||||
const std::shared_ptr<server::FBBuilder> &fbb);
|
const std::shared_ptr<server::FBBuilder> &fbb);
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
|
@ -27,24 +28,15 @@ void ExchangeKeysKernel::InitKernel(size_t) {
|
||||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
}
|
}
|
||||||
|
|
||||||
executor_ = &Executor::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
|
||||||
if (!executor_->initialized()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
cipher_key_ = &armour::CipherKeys::GetInstance();
|
cipher_key_ = &armour::CipherKeys::GetInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num) {
|
bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const size_t iter_num) {
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
std::string reason = "Current amount for exchangeKey is enough. Please retry later.";
|
std::string reason = "Current amount for exchangeKey is enough. Please retry later.";
|
||||||
cipher_key_->BuildExchangeKeysRsp(
|
cipher_key_->BuildExchangeKeysRsp(
|
||||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
|
||||||
IntToSize(iter_num));
|
|
||||||
MS_LOG(WARNING) << reason;
|
MS_LOG(WARNING) << reason;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -53,30 +45,84 @@ bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr<FBB
|
||||||
|
|
||||||
bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb,
|
bool ExchangeKeysKernel::CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
const schema::RequestExchangeKeys *exchange_keys_req,
|
const schema::RequestExchangeKeys *exchange_keys_req,
|
||||||
const int iter_num) {
|
const size_t iter_num) {
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(exchange_keys_req, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(exchange_keys_req, false);
|
||||||
if (!DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str())) {
|
if (!DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str())) {
|
||||||
std::string reason = "Counting for exchange kernel request failed. Please retry later.";
|
std::string reason = "Counting for exchange kernel request failed. Please retry later.";
|
||||||
cipher_key_->BuildExchangeKeysRsp(
|
cipher_key_->BuildExchangeKeysRsp(
|
||||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)),
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), iter_num);
|
||||||
IntToSize(iter_num));
|
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sigVerifyResult ExchangeKeysKernel::VerifySignature(const schema::RequestExchangeKeys *exchange_keys_req) {
|
||||||
|
std::string fl_id = exchange_keys_req->fl_id()->str();
|
||||||
|
std::string timestamp = exchange_keys_req->timestamp()->str();
|
||||||
|
int iteration = exchange_keys_req->iteration();
|
||||||
|
std::string iter_str = std::to_string(iteration);
|
||||||
|
auto fbs_signature = exchange_keys_req->signature();
|
||||||
|
std::vector<unsigned char> signature;
|
||||||
|
if (fbs_signature == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "signature in exchange_keys_req is nullptr";
|
||||||
|
return sigVerifyResult::FAILED;
|
||||||
|
}
|
||||||
|
signature.assign(fbs_signature->begin(), fbs_signature->end());
|
||||||
|
std::map<std::string, std::string> key_attestations;
|
||||||
|
const fl::PBMetadata &key_attestations_pb_out =
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
|
||||||
|
const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
|
||||||
|
auto iter = key_attestation_pb.key_attestations().begin();
|
||||||
|
for (; iter != key_attestation_pb.key_attestations().end(); ++iter) {
|
||||||
|
(void)key_attestations.emplace(std::pair<std::string, std::string>(iter->first, iter->second));
|
||||||
|
}
|
||||||
|
if (key_attestations.find(fl_id) == key_attestations.end()) {
|
||||||
|
MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
|
||||||
|
return sigVerifyResult::FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto fbs_cpk = exchange_keys_req->c_pk();
|
||||||
|
auto fbs_spk = exchange_keys_req->s_pk();
|
||||||
|
if (fbs_cpk == nullptr || fbs_spk == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "public key from exchange_keys_req is null";
|
||||||
|
return sigVerifyResult::FAILED;
|
||||||
|
}
|
||||||
|
size_t spk_len = fbs_spk->size();
|
||||||
|
size_t cpk_len = fbs_cpk->size();
|
||||||
|
std::vector<uint8_t> cpk(cpk_len);
|
||||||
|
std::vector<uint8_t> spk(spk_len);
|
||||||
|
bool ret_create_code_cpk = mindspore::armour::CreateArray<uint8_t>(&cpk, *fbs_cpk);
|
||||||
|
bool ret_create_code_spk = mindspore::armour::CreateArray<uint8_t>(&spk, *fbs_spk);
|
||||||
|
if (!(ret_create_code_cpk && ret_create_code_spk)) {
|
||||||
|
MS_LOG(ERROR) << "create array for public keys failed";
|
||||||
|
return sigVerifyResult::FAILED;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<unsigned char> src_data;
|
||||||
|
(void)src_data.insert(src_data.end(), cpk.begin(), cpk.end());
|
||||||
|
(void)src_data.insert(src_data.end(), spk.begin(), spk.end());
|
||||||
|
(void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
|
||||||
|
(void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());
|
||||||
|
mindspore::ps::server::CertVerify certVerify;
|
||||||
|
unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
|
||||||
|
certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
|
||||||
|
if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
|
||||||
|
return sigVerifyResult::FAILED;
|
||||||
|
}
|
||||||
|
if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
|
||||||
|
return sigVerifyResult::TIMEOUT;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
|
||||||
|
return sigVerifyResult::PASSED;
|
||||||
|
}
|
||||||
|
|
||||||
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
MS_LOG(INFO) << "Launching ExchangeKey kernel.";
|
|
||||||
bool response = false;
|
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
MS_LOG(INFO) << "Launching ExchangeKey kernel, ITERATION NUMBER IS : " << iter_num;
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is "
|
bool response = false;
|
||||||
<< total_duration;
|
|
||||||
clock_t start_time = clock();
|
|
||||||
|
|
||||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
std::string reason = "inputs or outputs size is invalid.";
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
|
@ -91,11 +137,56 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ReachThresholdForExchangeKeys(fbb, SizeToInt(iter_num))) {
|
if (ReachThresholdForExchangeKeys(fbb, iter_num)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||||
|
if (!verifier.VerifyBuffer<schema::RequestExchangeKeys>()) {
|
||||||
|
std::string reason = "The schema of RequestExchangeKeys is invalid.";
|
||||||
|
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
|
const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot<schema::RequestExchangeKeys>(req_data);
|
||||||
|
if (exchange_keys_req == nullptr) {
|
||||||
|
std::string reason = "Building flatbuffers schema failed for ExchangeKeys.";
|
||||||
|
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify signature
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
sigVerifyResult verify_result = VerifySignature(exchange_keys_req);
|
||||||
|
if (verify_result == sigVerifyResult::FAILED) {
|
||||||
|
std::string reason = "verify signature failed.";
|
||||||
|
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||||
|
std::string reason = "verify signature timestamp failed.";
|
||||||
|
cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::PASSED) {
|
||||||
|
MS_LOG(INFO) << "verify signature passed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
size_t iter_client = IntToSize(exchange_keys_req->iteration());
|
size_t iter_client = IntToSize(exchange_keys_req->iteration());
|
||||||
if (iter_num != iter_client) {
|
if (iter_num != iter_client) {
|
||||||
MS_LOG(ERROR) << "ExchangeKeys iteration number is invalid: server now iteration is " << iter_num
|
MS_LOG(ERROR) << "ExchangeKeys iteration number is invalid: server now iteration is " << iter_num
|
||||||
|
@ -107,19 +198,16 @@ bool ExchangeKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
||||||
}
|
}
|
||||||
response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
|
response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb);
|
||||||
if (!response) {
|
if (!response) {
|
||||||
MS_LOG(WARNING) << "update exchange keys is failed.";
|
MS_LOG(ERROR) << "update exchange keys is failed.";
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (!CountForExchangeKeys(fbb, exchange_keys_req, SizeToInt(iter_num))) {
|
if (!CountForExchangeKeys(fbb, exchange_keys_req, iter_num)) {
|
||||||
MS_LOG(ERROR) << "count for exchange keys failed.";
|
MS_LOG(ERROR) << "count for exchange keys failed.";
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
|
||||||
MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,11 +25,15 @@
|
||||||
#include "fl/server/kernel/round/round_kernel_factory.h"
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
#include "fl/server/executor.h"
|
#include "fl/server/executor.h"
|
||||||
#include "fl/armour/cipher/cipher_keys.h"
|
#include "fl/armour/cipher/cipher_keys.h"
|
||||||
|
#include "fl/armour/cipher/cipher_meta_storage.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
// results of signature verification
|
||||||
|
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
|
||||||
|
|
||||||
class ExchangeKeysKernel : public RoundKernel {
|
class ExchangeKeysKernel : public RoundKernel {
|
||||||
public:
|
public:
|
||||||
ExchangeKeysKernel() = default;
|
ExchangeKeysKernel() = default;
|
||||||
|
@ -43,9 +47,10 @@ class ExchangeKeysKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_;
|
||||||
armour::CipherKeys *cipher_key_;
|
armour::CipherKeys *cipher_key_;
|
||||||
bool ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const int iter_num);
|
sigVerifyResult VerifySignature(const schema::RequestExchangeKeys *exchange_keys_req);
|
||||||
|
bool ReachThresholdForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const size_t iter_num);
|
||||||
bool CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestExchangeKeys *exchange_keys_req,
|
bool CountForExchangeKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestExchangeKeys *exchange_keys_req,
|
||||||
const int iter_num);
|
const size_t iter_num);
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
#include "fl/server/kernel/round/get_keys_kernel.h"
|
#include "fl/server/kernel/round/get_keys_kernel.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
|
@ -26,24 +28,16 @@ void GetKeysKernel::InitKernel(size_t) {
|
||||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
}
|
}
|
||||||
|
|
||||||
executor_ = &Executor::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
|
||||||
if (!executor_->initialized()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
cipher_key_ = &armour::CipherKeys::GetInstance();
|
cipher_key_ = &armour::CipherKeys::GetInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
|
bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
|
||||||
const int iter_num) {
|
const size_t iter_num) {
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(get_keys_req, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(get_keys_req, false);
|
||||||
if (!DistributedCountService::GetInstance().Count(name_, get_keys_req->fl_id()->str())) {
|
if (!DistributedCountService::GetInstance().Count(name_, get_keys_req->fl_id()->str())) {
|
||||||
std::string reason = "Counting for getkeys kernel request failed. Please retry later.";
|
std::string reason = "Counting for getkeys kernel request failed. Please retry later.";
|
||||||
cipher_key_->BuildGetKeysRsp(
|
cipher_key_->BuildGetKeysRsp(
|
||||||
fbb, schema::ResponseCode_OutOfTime, IntToSize(iter_num),
|
fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), false);
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)), false);
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
return false;
|
return false;
|
||||||
|
@ -51,16 +45,52 @@ bool GetKeysKernel::CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies the signature carried by a GetExchangeKeys request.
//
// The signed payload is the request's timestamp string followed by the decimal iteration number. Its SHA-256
// digest is checked with CertVerify::verifyRSAKey against the key attestation stored for the requesting fl_id
// under kCtxClientKeyAttestation.
//
// Returns:
//   FAILED  - a required field (fl_id, timestamp, signature) is missing, no attestation is stored for the
//             client, or verifyRSAKey rejects the signature.
//   TIMEOUT - the signature is accepted but verifyTimeStamp rejects the timestamp.
//   PASSED  - both checks succeed.
sigVerifyResult GetKeysKernel::VerifySignature(const schema::GetExchangeKeys *get_keys_req) {
  // The request is built by a remote client; flatbuffers accessors return nullptr for absent optional fields,
  // so every pointer has to be checked before it is dereferenced.
  if (get_keys_req == nullptr || get_keys_req->fl_id() == nullptr || get_keys_req->timestamp() == nullptr) {
    MS_LOG(ERROR) << "fl_id or timestamp in get_keys_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  auto fbs_signature = get_keys_req->signature();
  if (fbs_signature == nullptr) {
    MS_LOG(ERROR) << "signature in get_keys_req is nullptr";
    return sigVerifyResult::FAILED;
  }

  std::string fl_id = get_keys_req->fl_id()->str();
  std::string timestamp = get_keys_req->timestamp()->str();
  std::string iter_str = std::to_string(get_keys_req->iteration());
  std::vector<unsigned char> signature(fbs_signature->begin(), fbs_signature->end());

  // Look the client up directly in the stored attestation map instead of copying the whole map first.
  const fl::PBMetadata &key_attestations_pb_out =
    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
  const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
  const auto &attestations = key_attestation_pb.key_attestations();
  const auto attestation_iter = attestations.find(fl_id);
  if (attestation_iter == attestations.end()) {
    MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  std::string key_attestation = attestation_iter->second;

  // Signed payload: timestamp bytes immediately followed by the iteration number in decimal.
  std::vector<unsigned char> src_data;
  (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
  (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());

  mindspore::ps::server::CertVerify certVerify;
  unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
  certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
  // NOTE(review): the signature length is not passed to verifyRSAKey (only the digest length is); confirm that
  // verifyRSAKey does not read past the end of a short client-supplied signature.
  if (!certVerify.verifyRSAKey(key_attestation, srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
    return sigVerifyResult::FAILED;
  }
  if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
    return sigVerifyResult::TIMEOUT;
  }
  MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
  return sigVerifyResult::PASSED;
}
|
||||||
|
|
||||||
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
MS_LOG(INFO) << "Launching GetKeys kernel.";
|
|
||||||
bool response = false;
|
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
MS_LOG(INFO) << "Launching GetKeys kernel, ITERATION NUMBER IS : " << iter_num;
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetKeysKernel allowed Duration Is "
|
bool response = false;
|
||||||
<< total_duration;
|
|
||||||
clock_t start_time = clock();
|
|
||||||
|
|
||||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
std::string reason = "inputs or outputs size is invalid.";
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
|
@ -74,12 +104,54 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough.";
|
MS_LOG(WARNING) << "Current amount for GetKeysKernel is enough.";
|
||||||
|
}
|
||||||
|
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||||
|
if (!verifier.VerifyBuffer<schema::GetExchangeKeys>()) {
|
||||||
|
std::string reason = "The schema of GetExchangeKeys is invalid.";
|
||||||
|
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
|
||||||
|
if (get_exchange_keys_req == nullptr) {
|
||||||
|
std::string reason = "Building flatbuffers schema failed for GetExchangeKeys.";
|
||||||
|
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify signature
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
sigVerifyResult verify_result = VerifySignature(get_exchange_keys_req);
|
||||||
|
if (verify_result == sigVerifyResult::FAILED) {
|
||||||
|
std::string reason = "verify signature failed.";
|
||||||
|
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||||
|
std::string reason = "verify signature timestamp failed.";
|
||||||
|
cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), false);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::PASSED) {
|
||||||
|
MS_LOG(INFO) << "verify signature passed!";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot<schema::GetExchangeKeys>(req_data);
|
|
||||||
size_t iter_client = IntToSize(get_exchange_keys_req->iteration());
|
size_t iter_client = IntToSize(get_exchange_keys_req->iteration());
|
||||||
if (iter_num != iter_client) {
|
if (iter_num != iter_client) {
|
||||||
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
|
MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num
|
||||||
|
@ -91,18 +163,15 @@ bool GetKeysKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vec
|
||||||
}
|
}
|
||||||
response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
|
response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb);
|
||||||
if (!response) {
|
if (!response) {
|
||||||
MS_LOG(WARNING) << "get public keys is failed.";
|
MS_LOG(WARNING) << "get public keys not ready.";
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (!CountForGetKeys(fbb, get_exchange_keys_req, SizeToInt(iter_num))) {
|
if (!CountForGetKeys(fbb, get_exchange_keys_req, iter_num)) {
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
|
||||||
MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,9 @@ namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
// results of signature verification
|
||||||
|
// Returned by VerifySignature: FAILED when the signature (or a field needed to check it) is rejected,
// TIMEOUT when the signature is accepted but the timestamp check fails, PASSED when both checks succeed.
// NOTE(review): get_list_sign_kernel.h declares an identical `enum sigVerifyResult` in this same namespace;
// a translation unit that includes both headers would see a redefinition. Consider moving the enum to one
// shared header.
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
|
||||||
|
|
||||||
class GetKeysKernel : public RoundKernel {
|
class GetKeysKernel : public RoundKernel {
|
||||||
public:
|
public:
|
||||||
GetKeysKernel() = default;
|
GetKeysKernel() = default;
|
||||||
|
@ -43,8 +46,9 @@ class GetKeysKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_;
|
||||||
armour::CipherKeys *cipher_key_;
|
armour::CipherKeys *cipher_key_;
|
||||||
|
sigVerifyResult VerifySignature(const schema::GetExchangeKeys *get_keys_req);
|
||||||
bool CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
|
bool CountForGetKeys(const std::shared_ptr<FBBuilder> &fbb, const schema::GetExchangeKeys *get_keys_req,
|
||||||
const int iter_num);
|
const size_t iter_num);
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -0,0 +1,281 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "fl/server/kernel/round/get_list_sign_kernel.h"
|
||||||
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include "schema/cipher_generated.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace fl {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
void GetListSignKernel::InitKernel(size_t) {
|
||||||
|
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||||
|
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
|
}
|
||||||
|
cipher_init_ = &armour::CipherInit::GetInstance();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verifies the signature carried by a RequestAllClientListSign request.
//
// The signed payload is the request's timestamp string followed by the decimal iteration number. Its SHA-256
// digest is checked with CertVerify::verifyRSAKey against the key attestation stored for the requesting fl_id
// under kCtxClientKeyAttestation.
//
// Returns:
//   FAILED  - a required field (fl_id, timestamp, signature) is missing, no attestation is stored for the
//             client, or verifyRSAKey rejects the signature.
//   TIMEOUT - the signature is accepted but verifyTimeStamp rejects the timestamp.
//   PASSED  - both checks succeed.
sigVerifyResult GetListSignKernel::VerifySignature(const schema::RequestAllClientListSign *client_list_sign_req) {
  // The request is built by a remote client; flatbuffers accessors return nullptr for absent optional fields,
  // so every pointer has to be checked before it is dereferenced.
  if (client_list_sign_req == nullptr || client_list_sign_req->fl_id() == nullptr ||
      client_list_sign_req->timestamp() == nullptr) {
    MS_LOG(ERROR) << "fl_id or timestamp in client_list_sign_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  auto fbs_signature = client_list_sign_req->signature();
  if (fbs_signature == nullptr) {
    MS_LOG(ERROR) << "signature in client_list_sign_req is nullptr";
    return sigVerifyResult::FAILED;
  }

  std::string fl_id = client_list_sign_req->fl_id()->str();
  std::string timestamp = client_list_sign_req->timestamp()->str();
  std::string iter_str = std::to_string(client_list_sign_req->iteration());
  std::vector<unsigned char> signature(fbs_signature->begin(), fbs_signature->end());

  // Look the client up directly in the stored attestation map instead of copying the whole map first.
  const fl::PBMetadata &key_attestations_pb_out =
    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
  const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
  const auto &attestations = key_attestation_pb.key_attestations();
  const auto attestation_iter = attestations.find(fl_id);
  if (attestation_iter == attestations.end()) {
    MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  std::string key_attestation = attestation_iter->second;

  // Signed payload: timestamp bytes immediately followed by the iteration number in decimal.
  std::vector<unsigned char> src_data;
  (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
  (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());

  mindspore::ps::server::CertVerify certVerify;
  unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
  certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
  // NOTE(review): the signature length is not passed to verifyRSAKey (only the digest length is); confirm that
  // verifyRSAKey does not read past the end of a short client-supplied signature.
  if (!certVerify.verifyRSAKey(key_attestation, srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
    return sigVerifyResult::FAILED;
  }
  if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
    return sigVerifyResult::TIMEOUT;
  }
  MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
  return sigVerifyResult::PASSED;
}
|
||||||
|
|
||||||
|
bool GetListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &outputs) {
|
||||||
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
|
MS_LOG(INFO) << "Launching GetListSign kernel, Iteration number is " << iter_num;
|
||||||
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||||
|
void *req_data = inputs[0]->addr;
|
||||||
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::map<std::string, std::vector<unsigned char>> list_signs;
|
||||||
|
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||||
|
if (!verifier.VerifyBuffer<schema::RequestAllClientListSign>()) {
|
||||||
|
std::string reason = "The schema of RequestAllClientListSign is invalid.";
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const schema::RequestAllClientListSign *get_list_sign_req =
|
||||||
|
flatbuffers::GetRoot<schema::RequestAllClientListSign>(req_data);
|
||||||
|
if (get_list_sign_req == nullptr) {
|
||||||
|
std::string reason = "Building flatbuffers schema failed for RequestAllClientListSign.";
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify signature
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
sigVerifyResult verify_result = VerifySignature(get_list_sign_req);
|
||||||
|
if (verify_result == sigVerifyResult::FAILED) {
|
||||||
|
std::string reason = "verify signature failed.";
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||||
|
std::string reason = "verify signature timestamp failed.";
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
|
||||||
|
iter_num, list_signs);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::PASSED) {
|
||||||
|
MS_LOG(INFO) << "verify signature passed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t iter_client = IntToSize(get_list_sign_req->iteration());
|
||||||
|
if (iter_num != iter_client) {
|
||||||
|
MS_LOG(ERROR) << "get list sign iteration number is invalid: server now iteration is " << iter_num
|
||||||
|
<< ". client request iteration is " << iter_client;
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.",
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs);
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
std::string fl_id = get_list_sign_req->fl_id()->str();
|
||||||
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
|
MS_LOG(WARNING) << "Current amount for GetListSignKernel is enough.";
|
||||||
|
}
|
||||||
|
if (!GetListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_list_sign_req, fbb)) {
|
||||||
|
MS_LOG(WARNING) << "get list signs not ready.";
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
std::string count_reason = "";
|
||||||
|
if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) {
|
||||||
|
std::string reason = "Counting for get list sign request failed. Please retry later. " + count_reason;
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
|
||||||
|
iter_num, list_signs);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GetListSignKernel::GetListSign(const size_t cur_iterator, const std::string &next_req_time,
|
||||||
|
const schema::RequestAllClientListSign *get_list_sign_req,
|
||||||
|
const std::shared_ptr<fl::server::FBBuilder> &fbb) {
|
||||||
|
MS_LOG(INFO) << "CipherMgr::SendClientListSign START";
|
||||||
|
std::map<std::string, std::vector<unsigned char>> client_list_signs_empty;
|
||||||
|
std::map<std::string, std::vector<unsigned char>> client_list_signs_all;
|
||||||
|
const fl::PBMetadata &clients_sign_pb_out =
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientListSigns);
|
||||||
|
const fl::ClientListSign &clients_sign_pb = clients_sign_pb_out.client_list_sign();
|
||||||
|
size_t cur_clients_sign_num = IntToSize(clients_sign_pb.client_list_sign_size());
|
||||||
|
if (cur_clients_sign_num < cipher_init_->push_list_sign_threshold) {
|
||||||
|
MS_LOG(INFO) << "The server is not ready. push_list_sign_needed: " << cipher_init_->push_list_sign_threshold;
|
||||||
|
MS_LOG(INFO) << "now push_sign_client_num: " << clients_sign_pb.client_list_sign_size();
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", next_req_time,
|
||||||
|
cur_iterator, client_list_signs_empty);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<string> update_model_clients;
|
||||||
|
const PBMetadata update_model_clients_pb_out =
|
||||||
|
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||||
|
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
|
||||||
|
for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) {
|
||||||
|
update_model_clients.push_back(update_model_clients_pb.fl_id(SizeToInt(i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iter = clients_sign_pb.client_list_sign().begin();
|
||||||
|
for (; iter != clients_sign_pb.client_list_sign().end(); ++iter) {
|
||||||
|
std::vector<unsigned char> signature(iter->second.begin(), iter->second.end());
|
||||||
|
(void)client_list_signs_all.emplace(std::pair<std::string, std::vector<unsigned char>>(iter->first, signature));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string fl_id = get_list_sign_req->fl_id()->str();
|
||||||
|
if (client_list_signs_all.find(fl_id) == client_list_signs_all.end()) {
|
||||||
|
std::string reason;
|
||||||
|
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
|
||||||
|
reason = "client not send list signature, but in update model client list.";
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator,
|
||||||
|
client_list_signs_all);
|
||||||
|
} else {
|
||||||
|
reason = "client not send list signature, && client is illegal";
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator,
|
||||||
|
client_list_signs_empty);
|
||||||
|
}
|
||||||
|
MS_LOG(WARNING) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (client_list_signs_all.find(fl_id) != client_list_signs_all.end()) {
|
||||||
|
// the client has sended signature, return false.
|
||||||
|
std::string reason = "The server has received the request, please do not request again.";
|
||||||
|
MS_LOG(WARNING) << reason;
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator,
|
||||||
|
client_list_signs_all);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string reason = "send update model client list signature success. ";
|
||||||
|
BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator,
|
||||||
|
client_list_signs_all);
|
||||||
|
MS_LOG(INFO) << "CipherMgr::Send Client ListSign Success";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clears this kernel's per-iteration state: logs the current iteration number, resets the distributed counter
// registered under name_, and stops the round timer. Always returns true.
bool GetListSignKernel::Reset() {
  const auto curr_iter_num = LocalMetaStore::GetInstance().curr_iter_num();
  MS_LOG(INFO) << "ITERATION NUMBER IS : " << curr_iter_num;
  MS_LOG(INFO) << "Get List Signature kernel reset!";

  DistributedCountService::GetInstance().ResetCounter(name_);
  StopTimer();
  return true;
}
|
||||||
|
|
||||||
|
// Serializes a ReturnAllClientListSign message into fbb and finishes the buffer.
//
// retcode, reason, next_req_time and iteration are always written. The client_list_sign field is written only
// when list_signs is non-empty, with one ClientListSign entry (fl_id + signature bytes) per map element.
void GetListSignKernel::BuildGetListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb,
                                                  const schema::ResponseCode retcode, const string &reason,
                                                  const string &next_req_time, const size_t iteration,
                                                  const std::map<std::string, std::vector<unsigned char>> &list_signs) {
  // Strings and vectors have to be created before the table builder is started.
  const auto reason_offset = fbb->CreateString(reason);
  const auto next_req_time_offset = fbb->CreateString(next_req_time);

  const bool has_signs = !list_signs.empty();
  std::vector<flatbuffers::Offset<schema::ClientListSign>> sign_entries;
  decltype(fbb->CreateVector(sign_entries)) all_signs_offset{};
  if (has_signs) {
    for (const auto &entry : list_signs) {
      const auto fl_id_offset = fbb->CreateString(entry.first);
      const auto sign_offset = fbb->CreateVector(entry.second.data(), entry.second.size());
      sign_entries.push_back(schema::CreateClientListSign(*fbb, fl_id_offset, sign_offset));
    }
    all_signs_offset = fbb->CreateVector(sign_entries);
  }

  schema::ReturnAllClientListSignBuilder rsp_builder(*fbb);
  rsp_builder.add_retcode(static_cast<int>(retcode));
  rsp_builder.add_reason(reason_offset);
  rsp_builder.add_next_req_time(next_req_time_offset);
  rsp_builder.add_iteration(SizeToInt(iteration));
  if (has_signs) {
    rsp_builder.add_client_list_sign(all_signs_offset);
  }
  fbb->Finish(rsp_builder.Finish());
}
|
||||||
|
|
||||||
|
REG_ROUND_KERNEL(getListSign, GetListSignKernel)
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace fl
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_LIST_SIGN_KERNEL_H
|
||||||
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_LIST_SIGN_KERNEL_H
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include "fl/server/common.h"
|
||||||
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
|
#include "fl/armour/cipher/cipher_init.h"
|
||||||
|
#include "fl/server/executor.h"
|
||||||
|
namespace mindspore {
|
||||||
|
namespace fl {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
// results of signature verification
|
||||||
|
// Returned by VerifySignature: FAILED when the signature (or a field needed to check it) is rejected,
// TIMEOUT when the signature is accepted but the timestamp check fails, PASSED when both checks succeed.
// NOTE(review): the header that declares GetKeysKernel defines an identical `enum sigVerifyResult` in this
// same namespace; a translation unit that includes both headers would see a redefinition. Consider moving the
// enum to one shared header.
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
|
||||||
|
|
||||||
|
class GetListSignKernel : public RoundKernel {
|
||||||
|
public:
|
||||||
|
GetListSignKernel() = default;
|
||||||
|
~GetListSignKernel() override = default;
|
||||||
|
void InitKernel(size_t required_cnt) override;
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
bool Reset() override;
|
||||||
|
void BuildGetListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||||
|
const string &reason, const string &next_req_time, const size_t iteration,
|
||||||
|
const std::map<std::string, std::vector<unsigned char>> &list_signs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
armour::CipherInit *cipher_init_;
|
||||||
|
Executor *executor_;
|
||||||
|
size_t iteration_time_window_;
|
||||||
|
sigVerifyResult VerifySignature(const schema::RequestAllClientListSign *client_list_sign_req);
|
||||||
|
bool GetListSign(const size_t cur_iterator, const std::string &next_req_time,
|
||||||
|
const schema::RequestAllClientListSign *client_list_sign_req,
|
||||||
|
const std::shared_ptr<fl::server::FBBuilder> &fbb);
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace fl
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_LIST_SIGN_KERNEL_H
|
|
@ -67,9 +67,9 @@ bool GetModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std::ve
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
(void)++retry_count_;
|
++retry_count_;
|
||||||
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
|
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
|
||||||
MS_LOG(INFO) << "Launching GetModelKernel retry count is " << retry_count_.load();
|
MS_LOG(INFO) << "Launching GetModelKernel kernel. Retry count is " << retry_count_.load();
|
||||||
}
|
}
|
||||||
|
|
||||||
const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
|
const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot<schema::RequestGetModel>(req_data);
|
||||||
|
@ -95,14 +95,14 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
|
||||||
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
|
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
|
||||||
std::map<std::string, AddressPtr> feature_maps;
|
std::map<std::string, AddressPtr> feature_maps;
|
||||||
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t get_model_iter = IntToSize(get_model_req->iteration());
|
size_t get_model_iter = static_cast<size_t>(get_model_req->iteration());
|
||||||
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
|
const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model();
|
||||||
size_t latest_iter_num = iter_to_model.rbegin()->first;
|
size_t latest_iter_num = iter_to_model.rbegin()->first;
|
||||||
// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
|
// If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later.
|
||||||
if ((current_iter == get_model_iter && latest_iter_num != current_iter)) {
|
if (current_iter == get_model_iter && latest_iter_num != current_iter) {
|
||||||
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +
|
std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) +
|
||||||
". Maybe this is because\n" + "1.Client doesn't send enough update model requests.\n" +
|
". Maybe this is because\n" + "1. Client doesn't not send enough update model request.\n" +
|
||||||
"2. Worker has not push all the weights to servers.";
|
"2. Worker has not push weights to server.";
|
||||||
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
|
BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps,
|
||||||
std::to_string(next_req_time));
|
std::to_string(next_req_time));
|
||||||
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
|
if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) {
|
||||||
|
@ -120,7 +120,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons
|
||||||
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
|
feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
|
MS_LOG(INFO) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid()
|
||||||
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter;
|
<< ", next request time is " << next_req_time << ", current iteration is " << current_iter;
|
||||||
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
|
BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter),
|
||||||
current_iter, feature_maps, std::to_string(next_req_time));
|
current_iter, feature_maps, std::to_string(next_req_time));
|
||||||
|
|
|
@ -51,7 +51,7 @@ class GetModelKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
|
|
||||||
// The time window of one iteration.
|
// The time window of one iteration.
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_{0};
|
||||||
|
|
||||||
// The count of retrying because the iteration is not finished.
|
// The count of retrying because the iteration is not finished.
|
||||||
std::atomic<uint64_t> retry_count_;
|
std::atomic<uint64_t> retry_count_;
|
||||||
|
|
|
@ -30,23 +30,15 @@ void GetSecretsKernel::InitKernel(size_t) {
|
||||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
}
|
}
|
||||||
|
|
||||||
executor_ = &Executor::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
|
||||||
if (!executor_->initialized()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
cipher_share_ = &armour::CipherShares::GetInstance();
|
cipher_share_ = &armour::CipherShares::GetInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
|
bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
const schema::GetShareSecrets *get_secrets_req, const int iter_num) {
|
const schema::GetShareSecrets *get_secrets_req, const size_t iter_num) {
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false);
|
||||||
if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) {
|
if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) {
|
||||||
std::string reason = "Counting for get secrets kernel request failed. Please retry later.";
|
std::string reason = "Counting for get secrets kernel request failed. Please retry later.";
|
||||||
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, IntToSize(iter_num),
|
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num,
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()), nullptr);
|
std::to_string(CURRENT_TIME_MILLI.count()), nullptr);
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
return false;
|
return false;
|
||||||
|
@ -54,16 +46,52 @@ bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies the signature carried by a GetShareSecrets request.
//
// The signed payload is the request timestamp followed by the decimal iteration number. Its SHA-256 digest is checked
// against the key attestation stored for the client's fl_id, and then the timestamp itself is validated.
//
// Returns FAILED when a required field or the key attestation is missing or the signature does not verify, TIMEOUT
// when the timestamp check fails, and PASSED otherwise.
sigVerifyResult GetSecretsKernel::VerifySignature(const schema::GetShareSecrets *get_secrets_req) {
  MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, sigVerifyResult::FAILED);

  // Optional flatbuffers fields are returned as nullptr when the client omitted them, so every pointer coming from
  // the (untrusted) request is checked before it is dereferenced.
  const auto *fbs_fl_id = get_secrets_req->fl_id();
  const auto *fbs_timestamp = get_secrets_req->timestamp();
  if (fbs_fl_id == nullptr || fbs_timestamp == nullptr) {
    MS_LOG(ERROR) << "fl_id or timestamp in get_secrets_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  const auto *fbs_signature = get_secrets_req->signature();
  if (fbs_signature == nullptr) {
    MS_LOG(ERROR) << "signature in get_secrets_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  std::vector<unsigned char> signature(fbs_signature->begin(), fbs_signature->end());
  if (signature.empty()) {
    // An empty buffer can never be a valid signature; do not hand it to the RSA verification.
    MS_LOG(ERROR) << "signature in get_secrets_req is empty";
    return sigVerifyResult::FAILED;
  }

  std::string fl_id = fbs_fl_id->str();
  std::string timestamp = fbs_timestamp->str();
  std::string iter_str = std::to_string(get_secrets_req->iteration());

  // Look the attestation up directly in the stored metadata instead of copying the whole map first.
  const fl::PBMetadata &key_attestations_pb_out =
    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
  const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
  const auto &stored_attestations = key_attestation_pb.key_attestations();
  const auto attestation_it = stored_attestations.find(fl_id);
  if (attestation_it == stored_attestations.end()) {
    MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  std::string key_attestation = attestation_it->second;

  // Signed payload: timestamp bytes immediately followed by the iteration number rendered as decimal text.
  std::vector<unsigned char> src_data;
  src_data.reserve(timestamp.size() + iter_str.size());
  (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
  (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());

  mindspore::ps::server::CertVerify certVerify;
  unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
  certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
  if (!certVerify.verifyRSAKey(key_attestation, srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
    return sigVerifyResult::FAILED;
  }
  if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
    return sigVerifyResult::TIMEOUT;
  }
  MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
  return sigVerifyResult::PASSED;
}
|
||||||
|
|
||||||
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
bool response = false;
|
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
|
|
||||||
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
|
std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count());
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
MS_LOG(INFO) << "Launching get secrets kernel, ITERATION NUMBER IS : " << iter_num;
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is "
|
|
||||||
<< total_duration;
|
|
||||||
clock_t start_time = clock();
|
|
||||||
|
|
||||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
std::string reason = "inputs or outputs size is invalid.";
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
|
@ -78,8 +106,46 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||||
|
if (!verifier.VerifyBuffer<schema::GetShareSecrets>()) {
|
||||||
|
std::string reason = "The schema of GetShareSecrets is invalid.";
|
||||||
|
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
|
const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot<schema::GetShareSecrets>(req_data);
|
||||||
|
if (get_secrets_req == nullptr) {
|
||||||
|
std::string reason = "Building flatbuffers schema failed for GetExchangeKeys.";
|
||||||
|
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify signature
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
sigVerifyResult verify_result = VerifySignature(get_secrets_req);
|
||||||
|
if (verify_result == sigVerifyResult::FAILED) {
|
||||||
|
std::string reason = "verify signature failed.";
|
||||||
|
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||||
|
std::string reason = "verify signature timestamp failed.";
|
||||||
|
cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::PASSED) {
|
||||||
|
MS_LOG(INFO) << "verify signature passed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
size_t iter_client = IntToSize(get_secrets_req->iteration());
|
size_t iter_client = IntToSize(get_secrets_req->iteration());
|
||||||
if (iter_num != iter_client) {
|
if (iter_num != iter_client) {
|
||||||
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
|
MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||||
|
@ -93,20 +159,17 @@ bool GetSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
||||||
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
|
MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough.";
|
||||||
}
|
}
|
||||||
|
|
||||||
response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
|
bool response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp);
|
||||||
if (!response) {
|
if (!response) {
|
||||||
MS_LOG(WARNING) << "get secret shares is failed.";
|
MS_LOG(WARNING) << "get secret shares not ready.";
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (!CountForGetSecrets(fbb, get_secrets_req, SizeToInt(iter_num))) {
|
if (!CountForGetSecrets(fbb, get_secrets_req, iter_num)) {
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
|
||||||
MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,6 +29,9 @@ namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
// results of signature verification
|
||||||
|
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
|
||||||
|
|
||||||
class GetSecretsKernel : public RoundKernel {
|
class GetSecretsKernel : public RoundKernel {
|
||||||
public:
|
public:
|
||||||
GetSecretsKernel() = default;
|
GetSecretsKernel() = default;
|
||||||
|
@ -42,8 +45,9 @@ class GetSecretsKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_;
|
||||||
armour::CipherShares *cipher_share_;
|
armour::CipherShares *cipher_share_;
|
||||||
|
sigVerifyResult VerifySignature(const schema::GetShareSecrets *get_secrets_req);
|
||||||
bool CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::GetShareSecrets *get_secrets_req,
|
bool CountForGetSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::GetShareSecrets *get_secrets_req,
|
||||||
const int iter_num);
|
const size_t iter_num);
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -37,6 +37,13 @@ void PullWeightKernel::InitKernel(size_t) {
|
||||||
bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool PullWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
MS_LOG(DEBUG) << "Launching PullWeightKernel kernel.";
|
MS_LOG(DEBUG) << "Launching PullWeightKernel kernel.";
|
||||||
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void *req_data = inputs[0]->addr;
|
void *req_data = inputs[0]->addr;
|
||||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||||
if (fbb == nullptr || req_data == nullptr) {
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
@ -71,7 +78,7 @@ void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
}
|
}
|
||||||
std::map<std::string, AddressPtr> feature_maps = {};
|
std::map<std::string, AddressPtr> feature_maps = {};
|
||||||
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t pull_weight_iter = IntToSize(pull_weight_req->iteration());
|
size_t pull_weight_iter = static_cast<size_t>(pull_weight_req->iteration());
|
||||||
// The iteration from worker should be the same as server's, otherwise return SucNotReady so that worker could retry.
|
// The iteration from worker should be the same as server's, otherwise return SucNotReady so that worker could retry.
|
||||||
if (pull_weight_iter != current_iter) {
|
if (pull_weight_iter != current_iter) {
|
||||||
std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) +
|
std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) +
|
||||||
|
@ -91,7 +98,7 @@ void PullWeightKernel::PullWeight(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
weight_names.push_back(weights_names_fbs->Get(i)->str());
|
weight_names.push_back(weights_names_fbs->Get(i)->str());
|
||||||
}
|
}
|
||||||
if (!executor_->IsWeightAggrDone(weight_names) || !executor_->unmasked()) {
|
if (!executor_->IsWeightAggrDone(weight_names) || !executor_->unmasked()) {
|
||||||
(void)++retry_count_;
|
++retry_count_;
|
||||||
std::string reason = "The aggregation for the weights is not done yet.";
|
std::string reason = "The aggregation for the weights is not done yet.";
|
||||||
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps);
|
BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps);
|
||||||
if (retry_count_.load() % kPrintPullWeightForEveryRetryTime == 1) {
|
if (retry_count_.load() % kPrintPullWeightForEveryRetryTime == 1) {
|
||||||
|
@ -134,7 +141,7 @@ void PullWeightKernel::BuildPullWeightRsp(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
|
auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps);
|
||||||
|
|
||||||
schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get()));
|
schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get()));
|
||||||
rsp_pull_weight_builder.add_retcode(static_cast<int>(retcode));
|
rsp_pull_weight_builder.add_retcode(SizeToInt(retcode));
|
||||||
rsp_pull_weight_builder.add_reason(fbs_reason);
|
rsp_pull_weight_builder.add_reason(fbs_reason);
|
||||||
rsp_pull_weight_builder.add_iteration(SizeToInt(iteration));
|
rsp_pull_weight_builder.add_iteration(SizeToInt(iteration));
|
||||||
rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector);
|
rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector);
|
||||||
|
|
|
@ -0,0 +1,277 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "fl/server/kernel/round/push_list_sign_kernel.h"
|
||||||
|
#include <utility>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include "schema/cipher_generated.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace fl {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
void PushListSignKernel::InitKernel(size_t) {
|
||||||
|
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||||
|
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
|
}
|
||||||
|
cipher_init_ = &armour::CipherInit::GetInstance();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool PushListSignKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &outputs) {
|
||||||
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
|
MS_LOG(INFO) << "Launching PushListSignKernel, Iteration number is " << iter_num;
|
||||||
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||||
|
void *req_data = inputs[0]->addr;
|
||||||
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
std::string reason = "FBBuilder builder or req_data is nullptr.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||||
|
if (!verifier.VerifyBuffer<schema::SendClientListSign>()) {
|
||||||
|
std::string reason = "The schema of PushClientListSign is invalid.";
|
||||||
|
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
const schema::SendClientListSign *client_list_sign_req = flatbuffers::GetRoot<schema::SendClientListSign>(req_data);
|
||||||
|
if (client_list_sign_req == nullptr) {
|
||||||
|
std::string reason = "Building flatbuffers schema failed for PushClientListSign.";
|
||||||
|
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// verify signature
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
sigVerifyResult verify_result = VerifySignature(client_list_sign_req);
|
||||||
|
if (verify_result == sigVerifyResult::FAILED) {
|
||||||
|
std::string reason = "verify signature failed.";
|
||||||
|
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||||
|
std::string reason = "verify signature timestamp failed.";
|
||||||
|
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (verify_result == sigVerifyResult::PASSED) {
|
||||||
|
MS_LOG(INFO) << "verify signature passed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return LaunchForPushListSign(client_list_sign_req, iter_num, fbb, outputs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles a structurally valid (and, if enabled, signature-verified) SendClientListSign request.
//
// Steps:
//   1. reject requests whose iteration differs from the server's current iteration;
//   2. reject requests that carry no fl_id;
//   3. if the round has already collected enough requests, answer SUCCEED to clients in the update-model client list
//      and OutOfTime to everyone else;
//   4. store the client's signature through PushListSign and count the request.
//
// Every path serializes the response built in `fbb` into `outputs` and returns true.
bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req,
                                               const size_t &iter_num, const std::shared_ptr<server::FBBuilder> &fbb,
                                               const std::vector<AddressPtr> &outputs) {
  size_t iter_client = IntToSize(client_list_sign_req->iteration());
  if (iter_num != iter_client) {
    std::string reason = "push list sign iteration number is invalid";
    MS_LOG(WARNING) << reason;
    MS_LOG(WARNING) << "server now iteration is " << iter_num << ". client request iteration is " << iter_client;
    BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
                               iter_num);
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return true;
  }

  // fl_id is an optional flatbuffers field: it is nullptr when the client did not set it. Signature verification may
  // be disabled, so it cannot be assumed to have been checked before this point.
  const auto *fbs_fl_id = client_list_sign_req->fl_id();
  if (fbs_fl_id == nullptr) {
    std::string reason = "fl_id in client_list_sign_req is nullptr";
    BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason,
                               std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
    MS_LOG(ERROR) << reason;
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return true;
  }
  std::string fl_id = fbs_fl_id->str();

  // Collect the ids of the clients recorded under kCtxUpdateModelClientList.
  std::vector<std::string> update_model_clients;
  const PBMetadata update_model_clients_pb_out =
    DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
  const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
  const int client_num = update_model_clients_pb.fl_id_size();
  if (client_num > 0) {
    update_model_clients.reserve(static_cast<size_t>(client_num));
  }
  for (int i = 0; i < client_num; ++i) {
    update_model_clients.push_back(update_model_clients_pb.fl_id(i));
  }

  if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
    MS_LOG(ERROR) << "Current amount for PushListSignKernel is enough.";
    const bool in_update_model_list =
      std::find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end();
    // Clients in the update model client list get SUCCEED; any other client gets OutOfTime.
    const schema::ResponseCode retcode =
      in_update_model_list ? schema::ResponseCode_SUCCEED : schema::ResponseCode_OutOfTime;
    BuildPushListSignKernelRsp(fbb, retcode, "Current amount for PushListSignKernel is enough.",
                               std::to_string(CURRENT_TIME_MILLI.count()), iter_num);
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return true;
  }

  if (!PushListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), client_list_sign_req, fbb,
                    update_model_clients)) {
    MS_LOG(ERROR) << "push client list sign failed.";
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return true;
  }

  std::string count_reason = "";
  if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) {
    std::string reason = "Counting for push list sign request failed. Please retry later. " + count_reason;
    BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()),
                               iter_num);
    MS_LOG(ERROR) << reason;
    // The failure response has to reach the client as well; previously this path returned without writing it.
    GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
    return true;
  }
  GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
  return true;
}
|
||||||
|
|
||||||
|
// Verifies the request signature (`req_signature`) of a SendClientListSign request.
//
// The signed payload is the request timestamp followed by the decimal iteration number. Its SHA-256 digest is checked
// against the key attestation stored for the client's fl_id, and then the timestamp itself is validated.
//
// Returns FAILED when a required field or the key attestation is missing or the signature does not verify, TIMEOUT
// when the timestamp check fails, and PASSED otherwise.
sigVerifyResult PushListSignKernel::VerifySignature(const schema::SendClientListSign *client_list_sign_req) {
  if (client_list_sign_req == nullptr) {
    MS_LOG(ERROR) << "client_list_sign_req is nullptr";
    return sigVerifyResult::FAILED;
  }

  // Optional flatbuffers fields are returned as nullptr when the client omitted them, so every pointer coming from
  // the (untrusted) request is checked before it is dereferenced.
  const auto *fbs_fl_id = client_list_sign_req->fl_id();
  const auto *fbs_timestamp = client_list_sign_req->timestamp();
  if (fbs_fl_id == nullptr || fbs_timestamp == nullptr) {
    MS_LOG(ERROR) << "fl_id or timestamp in client_list_sign_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  const auto *fbs_signature = client_list_sign_req->req_signature();
  if (fbs_signature == nullptr) {
    MS_LOG(ERROR) << "signature in client_list_sign_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  std::vector<unsigned char> signature(fbs_signature->begin(), fbs_signature->end());
  if (signature.empty()) {
    // An empty buffer can never be a valid signature; do not hand it to the RSA verification.
    MS_LOG(ERROR) << "signature in client_list_sign_req is empty";
    return sigVerifyResult::FAILED;
  }

  std::string fl_id = fbs_fl_id->str();
  std::string timestamp = fbs_timestamp->str();
  std::string iter_str = std::to_string(client_list_sign_req->iteration());

  // Look the attestation up directly in the stored metadata instead of copying the whole map first.
  const fl::PBMetadata &key_attestations_pb_out =
    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
  const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation();
  const auto &stored_attestations = key_attestation_pb.key_attestations();
  const auto attestation_it = stored_attestations.find(fl_id);
  if (attestation_it == stored_attestations.end()) {
    MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  std::string key_attestation = attestation_it->second;

  // Signed payload: timestamp bytes immediately followed by the iteration number rendered as decimal text.
  std::vector<unsigned char> src_data;
  src_data.reserve(timestamp.size() + iter_str.size());
  (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
  (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());

  mindspore::ps::server::CertVerify certVerify;
  unsigned char srcDataHash[SHA256_DIGEST_LENGTH];
  certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
  if (!certVerify.verifyRSAKey(key_attestation, srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
    return sigVerifyResult::FAILED;
  }
  if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
    return sigVerifyResult::TIMEOUT;
  }
  MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
  return sigVerifyResult::PASSED;
}
|
||||||
|
|
||||||
|
bool PushListSignKernel::PushListSign(const size_t cur_iterator, const std::string &next_req_time,
|
||||||
|
const schema::SendClientListSign *client_list_sign_req,
|
||||||
|
const std::shared_ptr<fl::server::FBBuilder> &fbb,
|
||||||
|
const std::vector<std::string> &update_model_clients) {
|
||||||
|
MS_LOG(INFO) << "CipherMgr::PushClientListSign START";
|
||||||
|
std::vector<std::string> get_client_list; // the clients which get update model client list
|
||||||
|
cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetUpdateModelClientList,
|
||||||
|
&get_client_list);
|
||||||
|
std::string fl_id = client_list_sign_req->fl_id()->str();
|
||||||
|
if (find(get_client_list.begin(), get_client_list.end(), fl_id) == get_client_list.end()) {
|
||||||
|
// client not in get update model client list.
|
||||||
|
std::string reason = "client send signature is not in get update model client list. && client is illegal";
|
||||||
|
MS_LOG(WARNING) << reason;
|
||||||
|
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
|
||||||
|
// client in update model client list, client can move to next round
|
||||||
|
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator);
|
||||||
|
} else {
|
||||||
|
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::vector<std::string> send_signs_clients;
|
||||||
|
const fl::PBMetadata &clients_sign_pb_out =
|
||||||
|
fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientListSigns);
|
||||||
|
const fl::ClientListSign &clients_sign_pb = clients_sign_pb_out.client_list_sign();
|
||||||
|
auto iter = clients_sign_pb.client_list_sign().begin();
|
||||||
|
for (; iter != clients_sign_pb.client_list_sign().end(); ++iter) {
|
||||||
|
send_signs_clients.push_back(iter->first);
|
||||||
|
}
|
||||||
|
if (find(send_signs_clients.begin(), send_signs_clients.end(), fl_id) != send_signs_clients.end()) {
|
||||||
|
// the client has sended signature, return false.
|
||||||
|
std::string reason = "The server has received the request, please do not request again.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto fbs_signature = client_list_sign_req->signature();
|
||||||
|
std::vector<char> signature;
|
||||||
|
if (fbs_signature != nullptr) {
|
||||||
|
signature.assign(fbs_signature->begin(), fbs_signature->end());
|
||||||
|
}
|
||||||
|
fl::PairClientListSign pair_client_list_sign_pb;
|
||||||
|
pair_client_list_sign_pb.set_fl_id(fl_id);
|
||||||
|
pair_client_list_sign_pb.set_signature(signature.data(), signature.size());
|
||||||
|
fl::PBMetadata pb_data;
|
||||||
|
pb_data.mutable_pair_client_list_sign()->MergeFrom(pair_client_list_sign_pb);
|
||||||
|
bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxClientListSigns, pb_data);
|
||||||
|
if (!retcode) {
|
||||||
|
std::string reason = "store client list signature failed";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::string reason = "send update model client list signature success. ";
|
||||||
|
BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator);
|
||||||
|
MS_LOG(INFO) << "CipherMgr::PushClientListSign Success";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clears this round's state: resets the request counter registered under this kernel's name, resets the stored
// client list signatures and stops the round timer. Always returns true.
bool PushListSignKernel::Reset() {
  const auto curr_iter = LocalMetaStore::GetInstance().curr_iter_num();
  MS_LOG(INFO) << "ITERATION NUMBER IS : " << curr_iter;
  MS_LOG(INFO) << "Push list sign kernel reset!";

  auto &count_service = DistributedCountService::GetInstance();
  count_service.ResetCounter(name_);

  auto &metadata_store = DistributedMetadataStore::GetInstance();
  metadata_store.ResetMetadata(kCtxClientListSigns);

  StopTimer();
  return true;
}
|
||||||
|
|
||||||
|
// Serializes a ResponseClientListSign message (retcode, reason, next request time, iteration) into `fbb` and
// finishes the buffer.
void PushListSignKernel::BuildPushListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb,
                                                    const schema::ResponseCode retcode, const string &reason,
                                                    const string &next_req_time, const size_t iteration) {
  // Both strings are created before the table builder is constructed, as in the original statement order.
  const auto reason_offset = fbb->CreateString(reason);
  const auto next_req_time_offset = fbb->CreateString(next_req_time);

  schema::ResponseClientListSignBuilder response_builder(*fbb);
  response_builder.add_retcode(static_cast<int>(retcode));
  response_builder.add_reason(reason_offset);
  response_builder.add_next_req_time(next_req_time_offset);
  response_builder.add_iteration(SizeToInt(iteration));
  fbb->Finish(response_builder.Finish());
}
|
||||||
|
|
||||||
|
REG_ROUND_KERNEL(pushListSign, PushListSignKernel)
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace fl
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_LIST_SIGN_KERNEL_H
|
||||||
|
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_LIST_SIGN_KERNEL_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "fl/server/common.h"
|
||||||
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
|
#include "fl/server/kernel/round/round_kernel_factory.h"
|
||||||
|
#include "fl/armour/cipher/cipher_init.h"
|
||||||
|
#include "fl/server/executor.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace fl {
|
||||||
|
namespace server {
|
||||||
|
namespace kernel {
|
||||||
|
// results of signature verification
|
||||||
|
// Return type of PushListSignKernel::VerifySignature.
// NOTE(review): in the sibling kernels touched by this change, FAILED is reported for a missing signature, a missing
// key attestation or a failed RSA check, and TIMEOUT for a rejected timestamp -- presumably the same applies here;
// confirm against push_list_sign_kernel.cc.
// NOTE(review): an identical enum is declared in reconstruct_secrets_kernel.h and share_secrets_kernel.h in the same
// namespace; including two of these headers in one translation unit would redefine it -- consider a shared header.
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
|
||||||
|
|
||||||
|
class PushListSignKernel : public RoundKernel {
|
||||||
|
public:
|
||||||
|
PushListSignKernel() = default;
|
||||||
|
~PushListSignKernel() override = default;
|
||||||
|
void InitKernel(size_t required_cnt) override;
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
bool LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req, const size_t &iter_num,
|
||||||
|
const std::shared_ptr<server::FBBuilder> &fbb, const std::vector<AddressPtr> &outputs);
|
||||||
|
bool Reset() override;
|
||||||
|
void BuildPushListSignKernelRsp(const std::shared_ptr<server::FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||||
|
const string &reason, const string &next_req_time, const size_t iteration);
|
||||||
|
|
||||||
|
private:
|
||||||
|
armour::CipherInit *cipher_init_;
|
||||||
|
Executor *executor_;
|
||||||
|
size_t iteration_time_window_;
|
||||||
|
sigVerifyResult VerifySignature(const schema::SendClientListSign *client_list_sign_req);
|
||||||
|
bool PushListSign(const size_t cur_iterator, const std::string &next_req_time,
|
||||||
|
const schema::SendClientListSign *client_list_sign_req,
|
||||||
|
const std::shared_ptr<fl::server::FBBuilder> &fbb,
|
||||||
|
const std::vector<std::string> &update_model_clients);
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace server
|
||||||
|
} // namespace fl
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_LIST_SIGN_KERNEL_H
|
|
@ -98,7 +98,7 @@ ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
void PushMetricsKernel::BuildPushMetricsRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode) {
|
void PushMetricsKernel::BuildPushMetricsRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode) {
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(fbb);
|
MS_ERROR_IF_NULL_WO_RET_VAL(fbb);
|
||||||
schema::ResponsePushMetricsBuilder rsp_push_metrics_builder(*(fbb.get()));
|
schema::ResponsePushMetricsBuilder rsp_push_metrics_builder(*(fbb.get()));
|
||||||
rsp_push_metrics_builder.add_retcode(static_cast<int>(retcode));
|
rsp_push_metrics_builder.add_retcode(retcode);
|
||||||
auto rsp_push_metrics = rsp_push_metrics_builder.Finish();
|
auto rsp_push_metrics = rsp_push_metrics_builder.Finish();
|
||||||
fbb->Finish(rsp_push_metrics);
|
fbb->Finish(rsp_push_metrics);
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,13 @@ void PushWeightKernel::InitKernel(size_t) {
|
||||||
bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool PushWeightKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
|
MS_LOG(INFO) << "Launching PushWeightKernel kernel.";
|
||||||
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, reason.c_str(), reason.size());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void *req_data = inputs[0]->addr;
|
void *req_data = inputs[0]->addr;
|
||||||
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
std::shared_ptr<FBBuilder> fbb = std::make_shared<FBBuilder>();
|
||||||
if (fbb == nullptr || req_data == nullptr) {
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
@ -141,7 +148,7 @@ void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
}
|
}
|
||||||
auto fbs_reason = fbb->CreateString(reason);
|
auto fbs_reason = fbb->CreateString(reason);
|
||||||
schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
|
schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get()));
|
||||||
rsp_push_weight_builder.add_retcode(static_cast<int>(retcode));
|
rsp_push_weight_builder.add_retcode(SizeToInt(retcode));
|
||||||
rsp_push_weight_builder.add_reason(fbs_reason);
|
rsp_push_weight_builder.add_reason(fbs_reason);
|
||||||
rsp_push_weight_builder.add_iteration(SizeToInt(iteration));
|
rsp_push_weight_builder.add_iteration(SizeToInt(iteration));
|
||||||
auto rsp_push_weight = rsp_push_weight_builder.Finish();
|
auto rsp_push_weight = rsp_push_weight_builder.Finish();
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
|
@ -27,7 +29,6 @@ void ReconstructSecretsKernel::InitKernel(size_t) {
|
||||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
|
auto last_cnt_handler = [&](std::shared_ptr<ps::core::MessageHandler>) {
|
||||||
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kReconstructSeccrets) {
|
if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kReconstructSeccrets) {
|
||||||
MS_LOG(INFO) << "start FinishIteration";
|
MS_LOG(INFO) << "start FinishIteration";
|
||||||
|
@ -44,14 +45,52 @@ void ReconstructSecretsKernel::InitKernel(size_t) {
|
||||||
{first_cnt_handler, last_cnt_handler});
|
{first_cnt_handler, last_cnt_handler});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies the signature attached to a SendReconstructSecret request.
//
// The signed payload is the request's timestamp string followed by the decimal text of its iteration number. The
// payload is hashed with CertVerify::sha256Hash and checked with CertVerify::verifyRSAKey against the key
// attestation stored for the client's fl_id under kCtxClientKeyAttestation.
//
// Returns:
//   sigVerifyResult::FAILED   the request or one of its fl_id / timestamp / signature fields is missing, the
//                             signature is empty, no key attestation is stored for fl_id, or the RSA check fails.
//   sigVerifyResult::TIMEOUT  the RSA check passes but CertVerify::verifyTimeStamp rejects the timestamp.
//   sigVerifyResult::PASSED   all checks pass.
sigVerifyResult ReconstructSecretsKernel::VerifySignature(const schema::SendReconstructSecret *reconstruct_secret_req) {
  if (reconstruct_secret_req == nullptr) {
    MS_LOG(ERROR) << "reconstruct_secret_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  // Optional flatbuffers fields are returned as null pointers when the client omitted them, so every field has to
  // be checked before it is dereferenced.
  const auto *fbs_fl_id = reconstruct_secret_req->fl_id();
  const auto *fbs_timestamp = reconstruct_secret_req->timestamp();
  const auto *fbs_signature = reconstruct_secret_req->signature();
  if (fbs_fl_id == nullptr || fbs_timestamp == nullptr) {
    MS_LOG(ERROR) << "fl_id or timestamp in reconstruct_secret_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  if (fbs_signature == nullptr) {
    MS_LOG(ERROR) << "signature in reconstruct_secret_req is nullptr";
    return sigVerifyResult::FAILED;
  }

  std::string fl_id = fbs_fl_id->str();
  std::string timestamp = fbs_timestamp->str();
  std::string iter_str = std::to_string(reconstruct_secret_req->iteration());
  std::vector<unsigned char> signature(fbs_signature->begin(), fbs_signature->end());
  if (signature.empty()) {
    MS_LOG(ERROR) << "signature in reconstruct_secret_req is empty, fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  // NOTE(review): the length of the signature is not passed to verifyRSAKey; presumably it expects a signature of
  // the RSA key size -- confirm, and reject shorter signatures if so.

  // Look the client up directly in the stored attestation map instead of copying the whole map for one lookup.
  const fl::PBMetadata &key_attestations_pb_out =
    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
  const auto &key_attestations = key_attestations_pb_out.key_attestation().key_attestations();
  const auto attestation_iter = key_attestations.find(fl_id);
  if (attestation_iter == key_attestations.end()) {
    MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  std::string key_attestation = attestation_iter->second;

  // Signed payload: timestamp bytes followed by the iteration number as decimal text.
  std::vector<unsigned char> src_data;
  src_data.reserve(timestamp.size() + iter_str.size());
  (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
  (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());

  mindspore::ps::server::CertVerify certVerify;
  unsigned char srcDataHash[SHA256_DIGEST_LENGTH] = {0};
  certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
  if (!certVerify.verifyRSAKey(key_attestation, srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
    MS_LOG(ERROR) << "verify RSA signature failed for fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
    MS_LOG(ERROR) << "verify timestamp failed for fl_id: " << fl_id;
    return sigVerifyResult::TIMEOUT;
  }
  MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
  return sigVerifyResult::PASSED;
}
|
||||||
|
|
||||||
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
bool response = false;
|
bool response = false;
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
MS_LOG(INFO) << "Launching ReconstructSecrets Kernel, Iteration number is " << iter_num;
|
||||||
MS_LOG(INFO) << "Iteration number is " << iter_num << ", ReconstructSecretsKernel total duration is "
|
|
||||||
<< total_duration;
|
|
||||||
clock_t start_time = clock();
|
|
||||||
|
|
||||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
|
MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size();
|
||||||
|
@ -73,14 +112,55 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
||||||
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList);
|
||||||
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
|
const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list();
|
||||||
|
|
||||||
for (int i = 0; i < update_model_clients_pb.fl_id_size(); ++i) {
|
for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) {
|
||||||
update_model_clients.push_back(update_model_clients_pb.fl_id(i));
|
update_model_clients.push_back(update_model_clients_pb.fl_id(i));
|
||||||
}
|
}
|
||||||
|
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||||
|
if (!verifier.VerifyBuffer<schema::SendReconstructSecret>()) {
|
||||||
|
std::string reason = "The schema of SendReconstructSecret is invalid.";
|
||||||
|
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num),
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
const schema::SendReconstructSecret *reconstruct_secret_req =
|
const schema::SendReconstructSecret *reconstruct_secret_req =
|
||||||
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
|
flatbuffers::GetRoot<schema::SendReconstructSecret>(req_data);
|
||||||
std::string fl_id = reconstruct_secret_req->fl_id()->str();
|
if (reconstruct_secret_req == nullptr) {
|
||||||
|
std::string reason = "Building flatbuffers schema failed for SendReconstructSecret.";
|
||||||
|
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num),
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// verify signature
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
sigVerifyResult verify_result = VerifySignature(reconstruct_secret_req);
|
||||||
|
if (verify_result == sigVerifyResult::FAILED) {
|
||||||
|
std::string reason = "verify signature failed.";
|
||||||
|
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
SizeToInt(iter_num), std::to_string(CURRENT_TIME_MILLI.count()));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||||
|
std::string reason = "verify signature timestamp failed.";
|
||||||
|
cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason, SizeToInt(iter_num),
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::PASSED) {
|
||||||
|
MS_LOG(INFO) << "verify signature passed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string fl_id = reconstruct_secret_req->fl_id()->str();
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
|
MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough.";
|
||||||
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
|
if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) {
|
||||||
|
@ -106,11 +186,10 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con
|
||||||
MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough.";
|
MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough.";
|
||||||
}
|
}
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
MS_LOG(INFO) << "reconstruct_secrets_kernel success.";
|
||||||
MS_LOG(INFO) << "reconstruct_secrets_kernel success time is : " << duration;
|
|
||||||
if (!response) {
|
if (!response) {
|
||||||
MS_LOG(INFO) << "reconstruct_secrets_kernel response is false.";
|
MS_LOG(INFO) << "reconstruct_secrets_kernel response not ready.";
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,9 @@ namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
// results of signature verification
|
||||||
|
// Return type of ReconstructSecretsKernel::VerifySignature:
//   FAILED   the request has no signature, no key attestation is stored for its fl_id, or the RSA check fails;
//            Launch answers with ResponseCode_RequestError.
//   TIMEOUT  the RSA check passed but the timestamp check failed; Launch answers with ResponseCode_OutOfTime.
//   PASSED   every check succeeded.
// NOTE(review): an identical enum is declared in push_list_sign_kernel.h and share_secrets_kernel.h in the same
// namespace; including two of these headers in one translation unit would redefine it -- consider a shared header.
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
|
||||||
|
|
||||||
class ReconstructSecretsKernel : public RoundKernel {
|
class ReconstructSecretsKernel : public RoundKernel {
|
||||||
public:
|
public:
|
||||||
ReconstructSecretsKernel() = default;
|
ReconstructSecretsKernel() = default;
|
||||||
|
@ -43,8 +46,10 @@ class ReconstructSecretsKernel : public RoundKernel {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string name_unmask_;
|
std::string name_unmask_;
|
||||||
|
Executor *executor_;
|
||||||
size_t iteration_time_window_{0};
|
size_t iteration_time_window_{0};
|
||||||
armour::CipherReconStruct cipher_reconstruct_;
|
armour::CipherReconStruct cipher_reconstruct_;
|
||||||
|
sigVerifyResult VerifySignature(const schema::SendReconstructSecret *reconstruct_secret_req);
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -124,7 +124,7 @@ void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, const v
|
||||||
outputs[0]->size = len;
|
outputs[0]->size = len;
|
||||||
|
|
||||||
std::unique_lock<std::mutex> lock(heap_data_mtx_);
|
std::unique_lock<std::mutex> lock(heap_data_mtx_);
|
||||||
(void)heap_data_.emplace(outputs[0], std::move(output_data));
|
(void)heap_data_.insert(std::make_pair(outputs[0], std::move(output_data)));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
|
|
|
@ -26,7 +26,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include "utils/hash_map.h"
|
#include <unordered_map>
|
||||||
#include "backend/kernel_compiler/common_utils.h"
|
#include "backend/kernel_compiler/common_utils.h"
|
||||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
@ -120,7 +120,7 @@ class RoundKernel : virtual public CPUKernel {
|
||||||
std::mutex release_mtx_;
|
std::mutex release_mtx_;
|
||||||
std::queue<AddressPtr> heap_data_to_release_;
|
std::queue<AddressPtr> heap_data_to_release_;
|
||||||
std::mutex heap_data_mtx_;
|
std::mutex heap_data_mtx_;
|
||||||
mindspore::HashMap<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
|
std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_;
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -20,8 +20,9 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "utils/hash_map.h"
|
#include <unordered_map>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
#include "fl/server/cert_verify.h"
|
||||||
#include "fl/server/kernel/round/round_kernel.h"
|
#include "fl/server/kernel/round/round_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -42,7 +43,7 @@ class RoundKernelFactory {
|
||||||
RoundKernelFactory(const RoundKernelFactory &) = delete;
|
RoundKernelFactory(const RoundKernelFactory &) = delete;
|
||||||
RoundKernelFactory &operator=(const RoundKernelFactory &) = delete;
|
RoundKernelFactory &operator=(const RoundKernelFactory &) = delete;
|
||||||
|
|
||||||
mindspore::HashMap<std::string, RoundKernelCreator> name_to_creator_map_;
|
std::unordered_map<std::string, RoundKernelCreator> name_to_creator_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class RoundKernelRegister {
|
class RoundKernelRegister {
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
#include "fl/server/kernel/round/share_secrets_kernel.h"
|
#include "fl/server/kernel/round/share_secrets_kernel.h"
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <map>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
|
@ -26,13 +28,6 @@ void ShareSecretsKernel::InitKernel(size_t) {
|
||||||
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) {
|
||||||
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
||||||
}
|
}
|
||||||
|
|
||||||
executor_ = &Executor::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(executor_);
|
|
||||||
if (!executor_->initialized()) {
|
|
||||||
MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline.";
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
cipher_share_ = &armour::CipherShares::GetInstance();
|
cipher_share_ = &armour::CipherShares::GetInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,21 +44,58 @@ bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr<FBBuilder> &
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies the signature attached to a RequestShareSecrets request.
//
// The signed payload is the request's timestamp string followed by the decimal text of its iteration number. The
// payload is hashed with CertVerify::sha256Hash and checked with CertVerify::verifyRSAKey against the key
// attestation stored for the client's fl_id under kCtxClientKeyAttestation.
//
// Returns:
//   sigVerifyResult::FAILED   the request or one of its fl_id / timestamp / signature fields is missing, the
//                             signature is empty, no key attestation is stored for fl_id, or the RSA check fails.
//   sigVerifyResult::TIMEOUT  the RSA check passes but CertVerify::verifyTimeStamp rejects the timestamp.
//   sigVerifyResult::PASSED   all checks pass.
sigVerifyResult ShareSecretsKernel::VerifySignature(const schema::RequestShareSecrets *share_secrets_req) {
  if (share_secrets_req == nullptr) {
    MS_LOG(ERROR) << "share_secrets_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  // Optional flatbuffers fields are returned as null pointers when the client omitted them, so every field has to
  // be checked before it is dereferenced.
  const auto *fbs_fl_id = share_secrets_req->fl_id();
  const auto *fbs_timestamp = share_secrets_req->timestamp();
  const auto *fbs_signature = share_secrets_req->signature();
  if (fbs_fl_id == nullptr || fbs_timestamp == nullptr) {
    MS_LOG(ERROR) << "fl_id or timestamp in share_secrets_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  if (fbs_signature == nullptr) {
    MS_LOG(ERROR) << "signature in share_secrets_req is nullptr";
    return sigVerifyResult::FAILED;
  }

  std::string fl_id = fbs_fl_id->str();
  std::string timestamp = fbs_timestamp->str();
  std::string iter_str = std::to_string(share_secrets_req->iteration());
  std::vector<unsigned char> signature(fbs_signature->begin(), fbs_signature->end());
  if (signature.empty()) {
    MS_LOG(ERROR) << "signature in share_secrets_req is empty, fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  // NOTE(review): the length of the signature is not passed to verifyRSAKey; presumably it expects a signature of
  // the RSA key size -- confirm, and reject shorter signatures if so.

  // Look the client up directly in the stored attestation map instead of copying the whole map for one lookup.
  const fl::PBMetadata &key_attestations_pb_out =
    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
  const auto &key_attestations = key_attestations_pb_out.key_attestation().key_attestations();
  const auto attestation_iter = key_attestations.find(fl_id);
  if (attestation_iter == key_attestations.end()) {
    MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  std::string key_attestation = attestation_iter->second;

  // Signed payload: timestamp bytes followed by the iteration number as decimal text.
  std::vector<unsigned char> src_data;
  src_data.reserve(timestamp.size() + iter_str.size());
  (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
  (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());

  mindspore::ps::server::CertVerify certVerify;
  unsigned char srcDataHash[SHA256_DIGEST_LENGTH] = {0};
  certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
  if (!certVerify.verifyRSAKey(key_attestation, srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
    MS_LOG(ERROR) << "verify RSA signature failed for fl_id: " << fl_id;
    return sigVerifyResult::FAILED;
  }
  if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
    MS_LOG(ERROR) << "verify timestamp failed for fl_id: " << fl_id;
    return sigVerifyResult::TIMEOUT;
  }
  MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
  return sigVerifyResult::PASSED;
}
|
||||||
|
|
||||||
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<AddressPtr> &outputs) {
|
const std::vector<AddressPtr> &outputs) {
|
||||||
bool response = false;
|
bool response = false;
|
||||||
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num();
|
||||||
size_t total_duration = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration);
|
MS_LOG(INFO) << "Launching ShareSecretsKernel, ITERATION NUMBER IS : " << iter_num;
|
||||||
MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is "
|
|
||||||
<< total_duration;
|
|
||||||
clock_t start_time = clock();
|
|
||||||
|
|
||||||
if (inputs.size() != 1 || outputs.size() != 1) {
|
if (inputs.size() != 1 || outputs.size() != 1) {
|
||||||
std::string reason = "inputs or outputs size is invalid.";
|
std::string reason = "inputs or outputs size is invalid.";
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
std::shared_ptr<server::FBBuilder> fbb = std::make_shared<server::FBBuilder>();
|
||||||
void *req_data = inputs[0]->addr;
|
void *req_data = inputs[0]->addr;
|
||||||
if (fbb == nullptr || req_data == nullptr) {
|
if (fbb == nullptr || req_data == nullptr) {
|
||||||
|
@ -71,7 +103,6 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
||||||
MS_LOG(ERROR) << reason;
|
MS_LOG(ERROR) << reason;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
if (DistributedCountService::GetInstance().CountReachThreshold(name_)) {
|
||||||
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
|
MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough.";
|
||||||
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime,
|
||||||
|
@ -80,7 +111,50 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
flatbuffers::Verifier verifier(reinterpret_cast<uint8_t *>(req_data), inputs[0]->size);
|
||||||
|
if (!verifier.VerifyBuffer<schema::RequestShareSecrets>()) {
|
||||||
|
std::string reason = "The schema of RequestShareSecrets is invalid.";
|
||||||
|
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
|
const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot<schema::RequestShareSecrets>(req_data);
|
||||||
|
if (share_secrets_req == nullptr) {
|
||||||
|
std::string reason = "Building flatbuffers schema failed for RequestShareSecrets.";
|
||||||
|
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// verify signature
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
sigVerifyResult verify_result = VerifySignature(share_secrets_req);
|
||||||
|
if (verify_result == sigVerifyResult::FAILED) {
|
||||||
|
std::string reason = "verify signature failed.";
|
||||||
|
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||||
|
std::string reason = "verify signature timestamp failed.";
|
||||||
|
cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason,
|
||||||
|
std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::PASSED) {
|
||||||
|
MS_LOG(INFO) << "verify signature passed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
size_t iter_client = IntToSize(share_secrets_req->iteration());
|
size_t iter_client = IntToSize(share_secrets_req->iteration());
|
||||||
if (iter_num != iter_client) {
|
if (iter_num != iter_client) {
|
||||||
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
|
MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num
|
||||||
|
@ -93,7 +167,7 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
||||||
response = cipher_share_->ShareSecrets(SizeToInt(iter_num), share_secrets_req, fbb,
|
response = cipher_share_->ShareSecrets(SizeToInt(iter_num), share_secrets_req, fbb,
|
||||||
std::to_string(CURRENT_TIME_MILLI.count()));
|
std::to_string(CURRENT_TIME_MILLI.count()));
|
||||||
if (!response) {
|
if (!response) {
|
||||||
MS_LOG(WARNING) << "update secret shares is failed.";
|
MS_LOG(ERROR) << "update secret shares is failed.";
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -102,9 +176,6 @@ bool ShareSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, const std
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
clock_t end_time = clock();
|
|
||||||
double duration = static_cast<double>((end_time - start_time) * 1.0 / CLOCKS_PER_SEC);
|
|
||||||
MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,9 @@ namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
|
// results of signature verification
|
||||||
|
// Return type of ShareSecretsKernel::VerifySignature:
//   FAILED   the request has no signature, no key attestation is stored for its fl_id, or the RSA check fails;
//            Launch answers with ResponseCode_RequestError.
//   TIMEOUT  the RSA check passed but the timestamp check failed; Launch answers with ResponseCode_OutOfTime.
//   PASSED   every check succeeded.
// NOTE(review): an identical enum is declared in push_list_sign_kernel.h and reconstruct_secrets_kernel.h in the
// same namespace; including two of these headers in one translation unit would redefine it -- consider a shared
// header.
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
|
||||||
|
|
||||||
class ShareSecretsKernel : public RoundKernel {
|
class ShareSecretsKernel : public RoundKernel {
|
||||||
public:
|
public:
|
||||||
ShareSecretsKernel() = default;
|
ShareSecretsKernel() = default;
|
||||||
|
@ -43,6 +46,7 @@ class ShareSecretsKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_;
|
||||||
armour::CipherShares *cipher_share_;
|
armour::CipherShares *cipher_share_;
|
||||||
|
sigVerifyResult VerifySignature(const schema::RequestShareSecrets *share_secrets_req);
|
||||||
bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req,
|
bool CountForShareSecrets(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestShareSecrets *share_secrets_req,
|
||||||
const size_t iter_num);
|
const size_t iter_num);
|
||||||
};
|
};
|
||||||
|
|
|
@ -19,11 +19,11 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "fl/server/model_store.h"
|
|
||||||
#include "fl/server/iteration.h"
|
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
#include "fl/armour/cipher/cipher_init.h"
|
#include "fl/armour/cipher/cipher_init.h"
|
||||||
#endif
|
#endif
|
||||||
|
#include "fl/server/model_store.h"
|
||||||
|
#include "fl/server/iteration.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
|
@ -46,6 +46,9 @@ void StartFLJobKernel::InitKernel(size_t) {
|
||||||
|
|
||||||
PBMetadata devices_metas;
|
PBMetadata devices_metas;
|
||||||
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas);
|
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas);
|
||||||
|
|
||||||
|
PBMetadata client_key_attestation;
|
||||||
|
DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxClientKeyAttestation, client_key_attestation);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,6 +96,17 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
if (!JudgeFLJobCert(fbb, start_fl_job_req)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!StoreKeyAttestation(fbb, start_fl_job_req)) {
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
|
DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req);
|
||||||
result_code = ReadyForStartFLJob(fbb, device_meta);
|
result_code = ReadyForStartFLJob(fbb, device_meta);
|
||||||
if (result_code != ResultCode::kSuccess) {
|
if (result_code != ResultCode::kSuccess) {
|
||||||
|
@ -122,16 +136,87 @@ bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool StartFLJobKernel::JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
|
const schema::RequestFLJob *start_fl_job_req) {
|
||||||
|
std::string fl_id = start_fl_job_req->fl_id()->str();
|
||||||
|
std::string timestamp = start_fl_job_req->timestamp()->str();
|
||||||
|
auto sign_data_vector = start_fl_job_req->sign_data();
|
||||||
|
if (sign_data_vector == nullptr || sign_data_vector->size() == 0) {
|
||||||
|
std::string reason = "sign data is empty.";
|
||||||
|
BuildStartFLJobRsp(
|
||||||
|
fbb, schema::ResponseCode_RequestError, reason, false,
|
||||||
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
unsigned char sign_data[sign_data_vector->size()];
|
||||||
|
|
||||||
|
for (unsigned int i = 0; i < sign_data_vector->size(); i++) {
|
||||||
|
sign_data[i] = sign_data_vector->Get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string key_attestation = start_fl_job_req->key_attestation()->str();
|
||||||
|
std::string equip_cert = start_fl_job_req->equip_cert()->str();
|
||||||
|
std::string equip_ca_cert = start_fl_job_req->equip_ca_cert()->str();
|
||||||
|
std::string root_first_ca_path = ps::PSContext::instance()->root_first_ca_path();
|
||||||
|
std::string root_second_ca_path = ps::PSContext::instance()->root_second_ca_path();
|
||||||
|
std::string equip_crl_path = ps::PSContext::instance()->equip_crl_path();
|
||||||
|
|
||||||
|
mindspore::ps::server::CertVerify certVerify;
|
||||||
|
bool ret =
|
||||||
|
certVerify.verifyCertAndSign(fl_id, timestamp, (const unsigned char *)sign_data, key_attestation, equip_cert,
|
||||||
|
equip_ca_cert, root_first_ca_path, root_second_ca_path, equip_crl_path);
|
||||||
|
if (!ret) {
|
||||||
|
std::string reason = "startFLJob sign and certificate verify failed.";
|
||||||
|
BuildStartFLJobRsp(
|
||||||
|
fbb, schema::ResponseCode_RequestError, reason, false,
|
||||||
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
} else {
|
||||||
|
MS_LOG(INFO) << "JudgeFLJobVerify success." << ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool StartFLJobKernel::StoreKeyAttestation(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
|
const schema::RequestFLJob *start_fl_job_req) {
|
||||||
|
// update key attestation
|
||||||
|
if (start_fl_job_req == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
std::string fl_id = start_fl_job_req->fl_id()->str();
|
||||||
|
std::string key_attestation = start_fl_job_req->key_attestation()->str();
|
||||||
|
|
||||||
|
fl::PairKeyAttestation pair_key_attestation_pb;
|
||||||
|
pair_key_attestation_pb.set_fl_id(fl_id);
|
||||||
|
pair_key_attestation_pb.set_certificate(key_attestation);
|
||||||
|
|
||||||
|
fl::PBMetadata pb_data;
|
||||||
|
pb_data.mutable_pair_key_attestation()->MergeFrom(pair_key_attestation_pb);
|
||||||
|
bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxClientKeyAttestation, pb_data);
|
||||||
|
if (!ret) {
|
||||||
|
std::string reason = "startFLJob: store key attestation failed";
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
BuildStartFLJobRsp(
|
||||||
|
fbb, schema::ResponseCode_OutOfTime, reason, false,
|
||||||
|
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool StartFLJobKernel::Reset() {
|
bool StartFLJobKernel::Reset() {
|
||||||
MS_LOG(INFO) << "Starting fl job kernel reset!";
|
MS_LOG(INFO) << "Starting fl job kernel reset!";
|
||||||
StopTimer();
|
StopTimer();
|
||||||
DistributedCountService::GetInstance().ResetCounter(name_);
|
DistributedCountService::GetInstance().ResetCounter(name_);
|
||||||
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas);
|
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas);
|
||||||
|
DistributedMetadataStore::GetInstance().ResetMetadata(kCtxClientKeyAttestation);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
|
void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &) {
|
||||||
iter_next_req_timestamp_ = LongToSize(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
|
iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_;
|
||||||
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
|
LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_);
|
||||||
// The first startFLJob request means a new iteration starts running.
|
// The first startFLJob request means a new iteration starts running.
|
||||||
Iteration::GetInstance().SetIterationRunning();
|
Iteration::GetInstance().SetIterationRunning();
|
||||||
|
@ -241,9 +326,11 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
fl_plan_builder.add_epochs(SizeToInt(ps::PSContext::instance()->client_epoch_num()));
|
fl_plan_builder.add_epochs(SizeToInt(ps::PSContext::instance()->client_epoch_num()));
|
||||||
fl_plan_builder.add_mini_batch(SizeToInt(ps::PSContext::instance()->client_batch_size()));
|
fl_plan_builder.add_mini_batch(SizeToInt(ps::PSContext::instance()->client_batch_size()));
|
||||||
fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate());
|
fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate());
|
||||||
|
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
fl_plan_builder.add_cipher(cipher_public_params);
|
fl_plan_builder.add_cipher(cipher_public_params);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
auto fbs_fl_plan = fl_plan_builder.Finish();
|
auto fbs_fl_plan = fl_plan_builder.Finish();
|
||||||
|
|
||||||
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
|
std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps;
|
||||||
|
|
|
@ -58,6 +58,10 @@ class StartFLJobKernel : public RoundKernel {
|
||||||
|
|
||||||
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
|
void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta);
|
||||||
|
|
||||||
|
bool JudgeFLJobCert(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
|
||||||
|
|
||||||
|
bool StoreKeyAttestation(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req);
|
||||||
|
|
||||||
// Build response for startFLJob round no matter success or failure.
|
// Build response for startFLJob round no matter success or failure.
|
||||||
void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||||
const std::string &reason, const bool is_selected, const std::string &next_req_time,
|
const std::string &reason, const bool is_selected, const std::string &next_req_time,
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
#include "fl/server/kernel/round/update_model_kernel.h"
|
#include "fl/server/kernel/round/update_model_kernel.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -85,6 +86,29 @@ bool UpdateModelKernel::Launch(const std::vector<AddressPtr> &inputs, const std:
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// verify signature
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
sigVerifyResult verify_result = VerifySignature(update_model_req);
|
||||||
|
if (verify_result == sigVerifyResult::FAILED) {
|
||||||
|
std::string reason = "verify signature failed.";
|
||||||
|
BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, "");
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verify_result == sigVerifyResult::TIMEOUT) {
|
||||||
|
std::string reason = "verify signature timestamp failed.";
|
||||||
|
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, "");
|
||||||
|
MS_LOG(ERROR) << reason;
|
||||||
|
GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (verify_result == sigVerifyResult::PASSED) {
|
||||||
|
MS_LOG(INFO) << "verify signature passed!";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
result_code = UpdateModel(update_model_req, fbb);
|
result_code = UpdateModel(update_model_req, fbb);
|
||||||
if (result_code != ResultCode::kSuccess) {
|
if (result_code != ResultCode::kSuccess) {
|
||||||
MS_LOG(ERROR) << "Updating model failed.";
|
MS_LOG(ERROR) << "Updating model failed.";
|
||||||
|
@ -144,12 +168,11 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
|
MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn);
|
||||||
size_t iteration = IntToSize(update_model_req->iteration());
|
size_t iteration = IntToSize(update_model_req->iteration());
|
||||||
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
|
if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) {
|
||||||
|
auto next_req_time = LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp);
|
||||||
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
|
std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) +
|
||||||
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) +
|
", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) +
|
||||||
". Retry later.";
|
". Retry later at time: " + std::to_string(next_req_time);
|
||||||
BuildUpdateModelRsp(
|
BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(next_req_time));
|
||||||
fbb, schema::ResponseCode_OutOfTime, reason,
|
|
||||||
std::to_string(LocalMetaStore::GetInstance().value<uint64_t>(kCtxIterationNextRequestTimestamp)));
|
|
||||||
MS_LOG(WARNING) << reason;
|
MS_LOG(WARNING) << reason;
|
||||||
return ResultCode::kSuccessAndReturn;
|
return ResultCode::kSuccessAndReturn;
|
||||||
}
|
}
|
||||||
|
@ -259,6 +282,47 @@ ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptr<FBBuilde
|
||||||
return ResultCode::kSuccess;
|
return ResultCode::kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies the signature attached to an updateModel request.
//
// The signed payload is the request timestamp string followed by the decimal iteration number; its SHA-256
// digest is checked against the request signature using the key attestation stored for the client's fl_id
// under kCtxClientKeyAttestation.
//
// update_model_req: the parsed updateModel request; it and its optional fields may be nullptr.
// Returns:
//   sigVerifyResult::FAILED  - the request is malformed (missing fl_id/timestamp/signature) or the signature
//                              does not verify.
//   sigVerifyResult::TIMEOUT - no key attestation is stored for fl_id, or the timestamp check fails.
//   sigVerifyResult::PASSED  - signature and timestamp are both accepted.
sigVerifyResult UpdateModelKernel::VerifySignature(const schema::RequestUpdateModel *update_model_req) {
  // Optional flatbuffers fields are returned as nullptr when unset, so validate them before dereferencing.
  if (update_model_req == nullptr || update_model_req->fl_id() == nullptr ||
      update_model_req->timestamp() == nullptr) {
    MS_LOG(ERROR) << "fl_id or timestamp in update_model_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  auto fbs_signature = update_model_req->signature();
  if (fbs_signature == nullptr) {
    MS_LOG(ERROR) << "signature in update_model_req is nullptr";
    return sigVerifyResult::FAILED;
  }
  std::vector<unsigned char> signature(fbs_signature->begin(), fbs_signature->end());
  if (signature.empty()) {
    // An empty signature can never verify; reject it before handing a possibly-null buffer to the verifier.
    MS_LOG(ERROR) << "signature in update_model_req is empty";
    return sigVerifyResult::FAILED;
  }

  std::string fl_id = update_model_req->fl_id()->str();
  std::string timestamp = update_model_req->timestamp()->str();
  std::string iter_str = std::to_string(update_model_req->iteration());

  // Look the client up directly in the stored attestation map rather than copying the whole map per request.
  const fl::PBMetadata &key_attestations_pb_out =
    fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation);
  const auto &key_attestations = key_attestations_pb_out.key_attestation().key_attestations();
  auto attestation_iter = key_attestations.find(fl_id);
  if (attestation_iter == key_attestations.end()) {
    MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id;
    return sigVerifyResult::TIMEOUT;
  }
  std::string key_attestation = attestation_iter->second;

  // Signed payload: timestamp bytes followed by the iteration number rendered as decimal text.
  std::vector<unsigned char> src_data;
  src_data.reserve(timestamp.size() + iter_str.size());
  (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end());
  (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end());

  mindspore::ps::server::CertVerify certVerify;
  unsigned char srcDataHash[SHA256_DIGEST_LENGTH] = {0};
  certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH);
  if (!certVerify.verifyRSAKey(key_attestation, srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) {
    return sigVerifyResult::FAILED;
  }
  if (!certVerify.verifyTimeStamp(fl_id, timestamp)) {
    return sigVerifyResult::TIMEOUT;
  }
  MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success.";
  return sigVerifyResult::PASSED;
}
|
||||||
|
|
||||||
void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||||
const std::string &reason, const std::string &next_req_time) {
|
const std::string &reason, const std::string &next_req_time) {
|
||||||
if (fbb == nullptr) {
|
if (fbb == nullptr) {
|
||||||
|
|
|
@ -35,10 +35,12 @@ namespace server {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
|
// The initial data size sum of federated learning is 0, which will be accumulated in updateModel round.
|
||||||
constexpr uint64_t kInitialDataSizeSum = 0;
|
constexpr uint64_t kInitialDataSizeSum = 0;
|
||||||
|
// results of signature verification
|
||||||
|
enum sigVerifyResult { FAILED, TIMEOUT, PASSED };
|
||||||
|
|
||||||
class UpdateModelKernel : public RoundKernel {
|
class UpdateModelKernel : public RoundKernel {
|
||||||
public:
|
public:
|
||||||
UpdateModelKernel() : executor_(nullptr), iteration_time_window_(0) {}
|
UpdateModelKernel() = default;
|
||||||
~UpdateModelKernel() override = default;
|
~UpdateModelKernel() override = default;
|
||||||
|
|
||||||
void InitKernel(size_t threshold_count) override;
|
void InitKernel(size_t threshold_count) override;
|
||||||
|
@ -55,6 +57,7 @@ class UpdateModelKernel : public RoundKernel {
|
||||||
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
|
std::map<std::string, UploadData> ParseFeatureMap(const schema::RequestUpdateModel *update_model_req);
|
||||||
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
|
ResultCode CountForUpdateModel(const std::shared_ptr<FBBuilder> &fbb,
|
||||||
const schema::RequestUpdateModel *update_model_req);
|
const schema::RequestUpdateModel *update_model_req);
|
||||||
|
sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req);
|
||||||
void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
void BuildUpdateModelRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode,
|
||||||
const std::string &reason, const std::string &next_req_time);
|
const std::string &reason, const std::string &next_req_time);
|
||||||
|
|
||||||
|
@ -62,7 +65,7 @@ class UpdateModelKernel : public RoundKernel {
|
||||||
Executor *executor_;
|
Executor *executor_;
|
||||||
|
|
||||||
// The time window of one iteration.
|
// The time window of one iteration.
|
||||||
size_t iteration_time_window_;
|
size_t iteration_time_window_{0};
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace server
|
} // namespace server
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "fl/server/kernel/sgd_kernel.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace fl {
|
|
||||||
namespace server {
|
|
||||||
namespace kernel {
|
|
||||||
REG_OPTIMIZER_KERNEL(SGD,
|
|
||||||
ParamsInfo()
|
|
||||||
.AddInputNameType(kWeight, kNumberTypeFloat32)
|
|
||||||
.AddInputNameType(kGradient, kNumberTypeFloat32)
|
|
||||||
.AddInputNameType(kLearningRate, kNumberTypeFloat32)
|
|
||||||
.AddInputNameType(kAccumulation, kNumberTypeFloat32)
|
|
||||||
.AddInputNameType(kMomentum, kNumberTypeFloat32)
|
|
||||||
.AddInputNameType(kStat, kNumberTypeFloat32),
|
|
||||||
SGDKernel, float)
|
|
||||||
} // namespace kernel
|
|
||||||
} // namespace server
|
|
||||||
} // namespace fl
|
|
||||||
} // namespace mindspore
|
|
|
@ -1,63 +0,0 @@
|
||||||
/**
|
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_
|
|
||||||
#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include <memory>
|
|
||||||
#include <utility>
|
|
||||||
#include "backend/kernel_compiler/cpu/sgd_cpu_kernel.h"
|
|
||||||
#include "fl/server/kernel/optimizer_kernel.h"
|
|
||||||
#include "fl/server/kernel/optimizer_kernel_factory.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace fl {
|
|
||||||
namespace server {
|
|
||||||
namespace kernel {
|
|
||||||
using mindspore::kernel::SGDCPUKernel;
|
|
||||||
template <typename T>
|
|
||||||
class SGDKernel : public SGDCPUKernel<T>, public OptimizerKernel {
|
|
||||||
public:
|
|
||||||
SGDKernel() = default;
|
|
||||||
~SGDKernel() override = default;
|
|
||||||
|
|
||||||
void InitKernel(const CNodePtr &cnode) override {
|
|
||||||
SGDCPUKernel<T>::InitKernel(cnode);
|
|
||||||
InitServerKernelInputOutputSize(cnode);
|
|
||||||
GenerateReuseKernelNodeInfo();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
|
||||||
const std::vector<AddressPtr> &outputs) override {
|
|
||||||
return SGDCPUKernel<T>::Launch(inputs, workspace, outputs);
|
|
||||||
}
|
|
||||||
|
|
||||||
void GenerateReuseKernelNodeInfo() override {
|
|
||||||
MS_LOG(INFO) << "SGD reuse 'weight', 'learning rate', 'accumulation', 'momentum' and 'stat' of the kernel node.";
|
|
||||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0));
|
|
||||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2));
|
|
||||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 3));
|
|
||||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kMomentum, 4));
|
|
||||||
reuse_kernel_node_inputs_info_.insert(std::make_pair(kStat, 5));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace kernel
|
|
||||||
} // namespace server
|
|
||||||
} // namespace fl
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_
|
|
|
@ -20,7 +20,7 @@
|
||||||
#include <any>
|
#include <any>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "utils/hash_map.h"
|
#include <unordered_map>
|
||||||
#include "fl/server/common.h"
|
#include "fl/server/common.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -77,10 +77,10 @@ class LocalMetaStore {
|
||||||
LocalMetaStore &operator=(const LocalMetaStore &) = delete;
|
LocalMetaStore &operator=(const LocalMetaStore &) = delete;
|
||||||
|
|
||||||
// key_to_meta_ stores metadata with key-value format.
|
// key_to_meta_ stores metadata with key-value format.
|
||||||
mindspore::HashMap<std::string, std::any> key_to_meta_;
|
std::unordered_map<std::string, std::any> key_to_meta_;
|
||||||
// This mutex makes sure that the operations on key_to_meta_ is threadsafe.
|
// This mutex makes sure that the operations on key_to_meta_ is threadsafe.
|
||||||
std::mutex mtx_;
|
std::mutex mtx_;
|
||||||
size_t curr_iter_num_;
|
size_t curr_iter_num_{0};
|
||||||
};
|
};
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace fl
|
} // namespace fl
|
||||||
|
|
|
@ -31,7 +31,7 @@ namespace server {
|
||||||
constexpr size_t kInitIterationNum = 0;
|
constexpr size_t kInitIterationNum = 0;
|
||||||
|
|
||||||
// The initial iteration number after ModelStore is reset.
|
// The initial iteration number after ModelStore is reset.
|
||||||
constexpr size_t kResetInitIterNum = 1;
|
constexpr size_t kResetInitialIterNum = 1;
|
||||||
|
|
||||||
// Server framework use ModelStore to store and query models.
|
// Server framework use ModelStore to store and query models.
|
||||||
// ModelStore stores multiple models because worker could get models of the previous iterations.
|
// ModelStore stores multiple models because worker could get models of the previous iterations.
|
||||||
|
|
|
@ -111,39 +111,6 @@ bool ParameterAggregator::LaunchAggregators() {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ParameterAggregator::LaunchOptimizers() {
|
|
||||||
for (auto &optimizer_with_params : optimizer_kernel_parameters_) {
|
|
||||||
KernelParams ¶ms = optimizer_with_params.second;
|
|
||||||
std::shared_ptr<kernel::OptimizerKernel> optimizer_kernel = optimizer_with_params.first;
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
|
|
||||||
bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs);
|
|
||||||
if (!ret) {
|
|
||||||
MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed.";
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// As long as all the optimizer kernels are launched, consider optimizing for this ParameterAggregator as done.
|
|
||||||
optimizing_done_ = true;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
AddressPtr ParameterAggregator::Pull() {
|
|
||||||
if (memory_register_ == nullptr) {
|
|
||||||
MS_LOG(ERROR)
|
|
||||||
<< "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first.";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
current_pull_count_++;
|
|
||||||
if (current_pull_count_ == required_pull_count_) {
|
|
||||||
pulling_done_ = true;
|
|
||||||
}
|
|
||||||
MS_LOG(DEBUG) << "The " << current_pull_count_ << " time of Pull. Pulling done status: " << pulling_done_;
|
|
||||||
|
|
||||||
std::map<std::string, AddressPtr> &name_to_addr = memory_register_->addresses();
|
|
||||||
return name_to_addr["weight"];
|
|
||||||
}
|
|
||||||
|
|
||||||
AddressPtr ParameterAggregator::GetWeight() {
|
AddressPtr ParameterAggregator::GetWeight() {
|
||||||
if (memory_register_ == nullptr) {
|
if (memory_register_ == nullptr) {
|
||||||
MS_LOG(ERROR)
|
MS_LOG(ERROR)
|
||||||
|
@ -193,8 +160,8 @@ bool ParameterAggregator::requires_aggr() const { return requires_aggr_; }
|
||||||
|
|
||||||
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
|
bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
if (!JudgeRequiresAggr(cnode)) {
|
if (!JudgeRequiredAggr(cnode)) {
|
||||||
MS_LOG(WARNING) << "Aggregation for weight for kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required.";
|
MS_LOG(WARNING) << "Aggregation for weight of kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required.";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
|
std::vector<std::string> aggr_kernel_names = SelectAggregationAlgorithm(cnode);
|
||||||
|
@ -223,32 +190,12 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
|
bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &) {
|
||||||
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
|
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
|
||||||
ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
|
ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) {
|
||||||
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
|
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
|
||||||
const std::string &name = AnfAlgo::GetCNodeName(cnode);
|
|
||||||
auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode);
|
|
||||||
if (optimizer_kernel == nullptr) {
|
|
||||||
MS_LOG(EXCEPTION) << "Failed to create optimizer kernel for " << name;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
optimizer_kernel->InitKernel(cnode);
|
|
||||||
|
|
||||||
const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = optimizer_kernel->reuse_kernel_node_inputs_info();
|
|
||||||
if (!AssignMemory(optimizer_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) {
|
|
||||||
MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!GenerateOptimizerKernelParams(optimizer_kernel, memory_register_)) {
|
|
||||||
MS_LOG(ERROR) << "Generating optimizer kernel parameters failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -323,29 +270,6 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ParameterAggregator::GenerateOptimizerKernelParams(
|
|
||||||
const std::shared_ptr<kernel::OptimizerKernel> &optimizer_kernel,
|
|
||||||
const std::shared_ptr<MemoryRegister> &memory_register) {
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false);
|
|
||||||
KernelParams optimizer_params = {};
|
|
||||||
|
|
||||||
const std::vector<std::string> &input_names = optimizer_kernel->input_names();
|
|
||||||
(void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs),
|
|
||||||
[&](const std::string &name) { return memory_register->addresses()[name]; });
|
|
||||||
|
|
||||||
const std::vector<std::string> &workspace_names = optimizer_kernel->workspace_names();
|
|
||||||
(void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace),
|
|
||||||
[&](const std::string &name) { return memory_register->addresses()[name]; });
|
|
||||||
|
|
||||||
const std::vector<std::string> &output_names = optimizer_kernel->output_names();
|
|
||||||
(void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs),
|
|
||||||
[&](const std::string &name) { return memory_register->addresses()[name]; });
|
|
||||||
|
|
||||||
optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params));
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
|
std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) {
|
||||||
std::vector<std::string> aggregation_algorithm = {};
|
std::vector<std::string> aggregation_algorithm = {};
|
||||||
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
|
if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL ||
|
||||||
|
@ -362,7 +286,7 @@ std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const C
|
||||||
return aggregation_algorithm;
|
return aggregation_algorithm;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) {
|
bool ParameterAggregator::JudgeRequiredAggr(const CNodePtr &cnode) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
|
std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
|
||||||
if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
|
if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 ||
|
||||||
|
@ -375,7 +299,7 @@ bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) {
|
||||||
MS_EXCEPTION_IF_NULL(weight_node);
|
MS_EXCEPTION_IF_NULL(weight_node);
|
||||||
|
|
||||||
if (!weight_node->isa<Parameter>()) {
|
if (!weight_node->isa<Parameter>()) {
|
||||||
MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter node.";
|
MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto param_info = weight_node->cast<ParameterPtr>()->param_info();
|
auto param_info = weight_node->cast<ParameterPtr>()->param_info();
|
||||||
|
|
|
@ -77,11 +77,6 @@ class ParameterAggregator {
|
||||||
|
|
||||||
// Launch aggregators/optimizers of this ParameterAggregator in order.
|
// Launch aggregators/optimizers of this ParameterAggregator in order.
|
||||||
bool LaunchAggregators();
|
bool LaunchAggregators();
|
||||||
bool LaunchOptimizers();
|
|
||||||
|
|
||||||
// The implementation for primitive Pull in parameter server training mode.
|
|
||||||
// Every call of this method will increase the count for pull by 1.
|
|
||||||
AddressPtr Pull();
|
|
||||||
|
|
||||||
// Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing
|
// Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing
|
||||||
// any change of status.
|
// any change of status.
|
||||||
|
@ -98,7 +93,6 @@ class ParameterAggregator {
|
||||||
bool IsOptimizingDone() const;
|
bool IsOptimizingDone() const;
|
||||||
bool IsPullingDone() const;
|
bool IsPullingDone() const;
|
||||||
|
|
||||||
// Return whether this parameter requires aggregation.
|
|
||||||
bool requires_aggr() const;
|
bool requires_aggr() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -119,15 +113,13 @@ class ParameterAggregator {
|
||||||
// memory_register.
|
// memory_register.
|
||||||
bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
|
bool GenerateAggregationKernelParams(const std::shared_ptr<kernel::AggregationKernel> &aggr_kernel,
|
||||||
const std::shared_ptr<MemoryRegister> &memory_register);
|
const std::shared_ptr<MemoryRegister> &memory_register);
|
||||||
bool GenerateOptimizerKernelParams(const std::shared_ptr<kernel::OptimizerKernel> &optim_kernel,
|
|
||||||
const std::shared_ptr<MemoryRegister> &memory_register);
|
|
||||||
|
|
||||||
// The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
|
// The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user
|
||||||
// configuration, etc.
|
// configuration, etc.
|
||||||
std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);
|
std::vector<std::string> SelectAggregationAlgorithm(const CNodePtr &cnode);
|
||||||
|
|
||||||
// Judge whether the parameter needs to be aggregated.
|
// Judge whether the parameter needs to be aggregated.
|
||||||
bool JudgeRequiresAggr(const CNodePtr &cnode);
|
bool JudgeRequiredAggr(const CNodePtr &cnode);
|
||||||
|
|
||||||
ServerMode server_mode_;
|
ServerMode server_mode_;
|
||||||
size_t required_push_count_;
|
size_t required_push_count_;
|
||||||
|
|
|
@ -38,17 +38,16 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
|
||||||
const FinishIterCb &finish_iteration_cb) {
|
const FinishIterCb &finish_iteration_cb) {
|
||||||
MS_EXCEPTION_IF_NULL(communicator);
|
MS_EXCEPTION_IF_NULL(communicator);
|
||||||
communicator_ = communicator;
|
communicator_ = communicator;
|
||||||
|
MS_LOG(INFO) << "Round " << name_ << " start initialize.";
|
||||||
// Register callback for round kernel.
|
|
||||||
communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
|
communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr<ps::core::MessageHandler> message) {
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||||
LaunchRoundKernel(message);
|
LaunchRoundKernel(message);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Callback when the iteration is finished.
|
// Callback when the iteration is finished.
|
||||||
finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
|
finish_iteration_cb_ = [this, finish_iteration_cb](bool, const std::string &) -> void {
|
||||||
std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
|
std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
|
||||||
finish_iteration_cb(is_iteration_valid, reason);
|
finish_iteration_cb(true, reason);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Callback for finalizing the server. This can only be called once.
|
// Callback for finalizing the server. This can only be called once.
|
||||||
|
@ -62,9 +61,9 @@ void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &commun
|
||||||
MS_EXCEPTION_IF_NULL(iter_timer_);
|
MS_EXCEPTION_IF_NULL(iter_timer_);
|
||||||
|
|
||||||
// 1.Set the timeout callback for the timer.
|
// 1.Set the timeout callback for the timer.
|
||||||
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void {
|
iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool, const std::string &) -> void {
|
||||||
std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration.";
|
std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration.";
|
||||||
timeout_cb(is_iteration_valid, reason);
|
timeout_cb(false, reason);
|
||||||
});
|
});
|
||||||
|
|
||||||
// 2.Stopping timer callback which will be set to the round kernel.
|
// 2.Stopping timer callback which will be set to the round kernel.
|
||||||
|
@ -139,7 +138,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
++Iteration::GetInstance().running_round_num_;
|
(void)(Iteration::GetInstance().running_round_num_++);
|
||||||
AddressPtr input = std::make_shared<Address>();
|
AddressPtr input = std::make_shared<Address>();
|
||||||
AddressPtr output = std::make_shared<Address>();
|
AddressPtr output = std::make_shared<Address>();
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(input);
|
MS_ERROR_IF_NULL_WO_RET_VAL(input);
|
||||||
|
@ -167,7 +166,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &m
|
||||||
reason = "Launching round kernel of round " + name_ + " failed.";
|
reason = "Launching round kernel of round " + name_ + " failed.";
|
||||||
Iteration::GetInstance().NotifyNext(false, reason);
|
Iteration::GetInstance().NotifyNext(false, reason);
|
||||||
}
|
}
|
||||||
--Iteration::GetInstance().running_round_num_;
|
(void)(Iteration::GetInstance().running_round_num_--);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster managers like K8S.
|
// The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster manager like K8S.
|
||||||
std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr;
|
std::shared_ptr<ps::core::CommunicatorBase> g_communicator_with_server = nullptr;
|
||||||
std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {};
|
std::vector<std::shared_ptr<ps::core::CommunicatorBase>> g_communicators_with_worker = {};
|
||||||
void SignalHandler(int signal) {
|
void SignalHandler(int signal) {
|
||||||
|
@ -45,7 +45,6 @@ void SignalHandler(int signal) {
|
||||||
|
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server);
|
MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server);
|
||||||
(void)g_communicator_with_server->Stop();
|
(void)g_communicator_with_server->Stop();
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
|
void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector<RoundConfig> &rounds_config,
|
||||||
|
@ -68,24 +67,10 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Each step of the server pipeline may have dependency on other steps, which includes:
|
|
||||||
|
|
||||||
// InitServerContext must be the first step to set contexts for later steps.
|
|
||||||
|
|
||||||
// Server Running relies on URL or Message Type Register:
|
|
||||||
// StartCommunicator---->InitIteration
|
|
||||||
|
|
||||||
// Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion:
|
|
||||||
// RegisterRoundKernel---->StartCommunicator
|
|
||||||
|
|
||||||
// Kernel Initialization relies on Executor Initialization:
|
|
||||||
// RegisterRoundKernel---->InitExecutor
|
|
||||||
|
|
||||||
// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
|
|
||||||
// InitCipher---->InitExecutor
|
|
||||||
void Server::Run() {
|
void Server::Run() {
|
||||||
std::unique_lock<std::mutex> lock(scaling_mtx_);
|
std::unique_lock<std::mutex> lock(scaling_mtx_);
|
||||||
InitServerContext();
|
InitServerContext();
|
||||||
|
InitPkiCertificate();
|
||||||
InitCluster();
|
InitCluster();
|
||||||
InitIteration();
|
InitIteration();
|
||||||
RegisterCommCallbacks();
|
RegisterCommCallbacks();
|
||||||
|
@ -98,6 +83,7 @@ void Server::Run() {
|
||||||
}
|
}
|
||||||
RegisterRoundKernel();
|
RegisterRoundKernel();
|
||||||
InitMetrics();
|
InitMetrics();
|
||||||
|
Recover();
|
||||||
MS_LOG(INFO) << "Server started successfully.";
|
MS_LOG(INFO) << "Server started successfully.";
|
||||||
safemode_ = false;
|
safemode_ = false;
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
|
@ -111,9 +97,27 @@ void Server::Run() {
|
||||||
MS_EXCEPTION_IF_NULL(communicator_with_server_);
|
MS_EXCEPTION_IF_NULL(communicator_with_server_);
|
||||||
communicator_with_server_->Join();
|
communicator_with_server_->Join();
|
||||||
MsException::Instance().CheckException();
|
MsException::Instance().CheckException();
|
||||||
|
func_graph_ = nullptr;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Server::InitPkiCertificate() {
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
root_first_ca_path_ = ps::PSContext::instance()->root_first_ca_path();
|
||||||
|
root_second_ca_path_ = ps::PSContext::instance()->root_second_ca_path();
|
||||||
|
equip_crl_path_ = ps::PSContext::instance()->equip_crl_path();
|
||||||
|
replay_attack_time_diff_ = ps::PSContext::instance()->replay_attack_time_diff();
|
||||||
|
|
||||||
|
bool ret = mindspore::ps::server::CertVerify::initRootCertAndCRL(root_first_ca_path_, root_second_ca_path_,
|
||||||
|
equip_crl_path_, replay_attack_time_diff_);
|
||||||
|
if (!ret) {
|
||||||
|
MS_LOG(EXCEPTION) << "init root cert and crl failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void Server::SwitchToSafeMode() {
|
void Server::SwitchToSafeMode() {
|
||||||
MS_LOG(INFO) << "Server switch to safemode.";
|
MS_LOG(INFO) << "Server switch to safemode.";
|
||||||
safemode_ = true;
|
safemode_ = true;
|
||||||
|
@ -138,11 +142,6 @@ void Server::InitServerContext() {
|
||||||
scheduler_port_ = ps::PSContext::instance()->scheduler_port();
|
scheduler_port_ = ps::PSContext::instance()->scheduler_port();
|
||||||
worker_num_ = ps::PSContext::instance()->initial_worker_num();
|
worker_num_ = ps::PSContext::instance()->initial_worker_num();
|
||||||
server_num_ = ps::PSContext::instance()->initial_server_num();
|
server_num_ = ps::PSContext::instance()->initial_server_num();
|
||||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
|
||||||
if (encrypt_type == ps::kPWEncryptType && server_num_ > 1) {
|
|
||||||
MS_LOG(EXCEPTION) << "Only single server is supported for PW_ENCRYPT now, but got server_num is:." << server_num_;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,6 +217,8 @@ void Server::InitIteration() {
|
||||||
cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
|
cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold;
|
||||||
cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
|
cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold;
|
||||||
cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
|
cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold;
|
||||||
|
cipher_push_list_sign_cnt_ = cipher_config_.push_list_sign_threshold;
|
||||||
|
cipher_get_list_sign_cnt_ = cipher_config_.get_list_sign_threshold;
|
||||||
cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
|
cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold;
|
||||||
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
|
cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1;
|
||||||
cipher_time_window_ = cipher_config_.cipher_time_window;
|
cipher_time_window_ = cipher_config_.cipher_time_window;
|
||||||
|
@ -228,6 +229,8 @@ void Server::InitIteration() {
|
||||||
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
|
<< " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_;
|
||||||
MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
|
MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_
|
||||||
<< " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
|
<< " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_
|
||||||
|
<< " cipher_push_list_sign_cnt_: " << cipher_push_list_sign_cnt_
|
||||||
|
<< " cipher_get_list_sign_cnt_: " << cipher_get_list_sign_cnt_
|
||||||
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
|
<< " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_
|
||||||
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
|
<< " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_
|
||||||
<< " cipher_time_window_: " << cipher_time_window_;
|
<< " cipher_time_window_: " << cipher_time_window_;
|
||||||
|
@ -245,6 +248,7 @@ void Server::InitIteration() {
|
||||||
void Server::InitCipher() {
|
void Server::InitCipher() {
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
cipher_init_ = &armour::CipherInit::GetInstance();
|
cipher_init_ = &armour::CipherInit::GetInstance();
|
||||||
|
|
||||||
int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
|
int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_);
|
||||||
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
|
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
|
||||||
const int cipher_g = 1;
|
const int cipher_g = 1;
|
||||||
|
@ -258,8 +262,7 @@ void Server::InitCipher() {
|
||||||
param.t = cipher_t;
|
param.t = cipher_t;
|
||||||
int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p));
|
int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p));
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
|
MS_LOG(EXCEPTION) << "Memcpy_s error, errorno" << ret;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
param.dp_delta = dp_delta;
|
param.dp_delta = dp_delta;
|
||||||
param.dp_eps = dp_eps;
|
param.dp_eps = dp_eps;
|
||||||
|
@ -281,10 +284,9 @@ void Server::InitCipher() {
|
||||||
if (prim != NULL) {
|
if (prim != NULL) {
|
||||||
BN_clear_free(prim);
|
BN_clear_free(prim);
|
||||||
}
|
}
|
||||||
|
cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
|
||||||
(void)cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_,
|
cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_push_list_sign_cnt_,
|
||||||
cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
|
cipher_get_list_sign_cnt_, cipher_reconstruct_secrets_up_cnt_);
|
||||||
cipher_reconstruct_secrets_up_cnt_);
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -366,12 +368,14 @@ void Server::RegisterMessageCallback(const std::shared_ptr<ps::core::TcpCommunic
|
||||||
std::bind(&Server::HandleNewInstanceRequest, this, std::placeholders::_1));
|
std::bind(&Server::HandleNewInstanceRequest, this, std::placeholders::_1));
|
||||||
communicator->RegisterMsgCallBack("queryInstance",
|
communicator->RegisterMsgCallBack("queryInstance",
|
||||||
std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1));
|
std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1));
|
||||||
|
communicator->RegisterMsgCallBack("syncAfterRecover",
|
||||||
|
std::bind(&Server::HandleSyncAfterRecoveryRequest, this, std::placeholders::_1));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Server::InitExecutor() {
|
void Server::InitExecutor() {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph_);
|
MS_EXCEPTION_IF_NULL(func_graph_);
|
||||||
if (executor_threshold_ == 0) {
|
if (executor_threshold_ == 0) {
|
||||||
MS_LOG(EXCEPTION) << "The executor's threshold should be greater than 0.";
|
MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// The train engine instance is used in both push-type and pull-type kernels,
|
// The train engine instance is used in both push-type and pull-type kernels,
|
||||||
|
@ -430,7 +434,6 @@ void Server::StartCommunicator() {
|
||||||
MS_LOG(INFO) << "Start communicator with server.";
|
MS_LOG(INFO) << "Start communicator with server.";
|
||||||
if (!communicator_with_server_->Start()) {
|
if (!communicator_with_server_->Start()) {
|
||||||
MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
|
MS_LOG(EXCEPTION) << "Starting communicator with server failed.";
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
DistributedMetadataStore::GetInstance().Initialize(server_node_);
|
DistributedMetadataStore::GetInstance().Initialize(server_node_);
|
||||||
CollectiveOpsImpl::GetInstance().Initialize(server_node_);
|
CollectiveOpsImpl::GetInstance().Initialize(server_node_);
|
||||||
|
@ -447,6 +450,32 @@ void Server::StartCommunicator() {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Server::Recover() {
|
||||||
|
server_recovery_ = std::make_shared<ServerRecovery>();
|
||||||
|
MS_EXCEPTION_IF_NULL(server_recovery_);
|
||||||
|
|
||||||
|
// Try to recover from persistent storage.
|
||||||
|
if (!server_recovery_->Initialize(ps::PSContext::instance()->config_file_path())) {
|
||||||
|
MS_LOG(WARNING) << "Initializing server recovery failed. Do not recover for this server.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (server_recovery_->Recover()) {
|
||||||
|
// If this server recovers, need to notify cluster to reach consistency.
|
||||||
|
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
||||||
|
MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
|
||||||
|
MS_LOG(INFO) << "Synchronize with leader server after recovery.";
|
||||||
|
if (!server_recovery_->SyncAfterRecovery(tcp_comm, server_node_->rank_id())) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to reach consistency of the cluster after recovery.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the recovery handler to Iteration.
|
||||||
|
MS_EXCEPTION_IF_NULL(iteration_);
|
||||||
|
iteration_->set_recovery_handler(server_recovery_);
|
||||||
|
}
|
||||||
|
|
||||||
void Server::ProcessBeforeScalingOut() {
|
void Server::ProcessBeforeScalingOut() {
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
|
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
|
||||||
iteration_->ScalingBarrier();
|
iteration_->ScalingBarrier();
|
||||||
|
@ -510,7 +539,6 @@ void Server::ProcessAfterScalingIn() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
|
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
|
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
|
||||||
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
||||||
|
@ -528,7 +556,6 @@ void Server::HandleEnableServerRequest(const std::shared_ptr<ps::core::MessageHa
|
||||||
}
|
}
|
||||||
|
|
||||||
void Server::HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
void Server::HandleDisableServerRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
|
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
|
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
|
||||||
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
||||||
|
@ -552,7 +579,6 @@ void Server::HandleNewInstanceRequest(const std::shared_ptr<ps::core::MessageHan
|
||||||
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
|
MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
|
||||||
|
|
||||||
MS_ERROR_IF_NULL_WO_RET_VAL(message->data());
|
|
||||||
std::string hyper_params_str(static_cast<const char *>(message->data()), message->len());
|
std::string hyper_params_str(static_cast<const char *>(message->data()), message->len());
|
||||||
nlohmann::json new_instance_json;
|
nlohmann::json new_instance_json;
|
||||||
nlohmann::json response;
|
nlohmann::json response;
|
||||||
|
@ -595,6 +621,31 @@ void Server::HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageH
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Server::HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::MessageHandler> &message) {
|
||||||
|
MS_ERROR_IF_NULL_WO_RET_VAL(message);
|
||||||
|
MS_ERROR_IF_NULL_WO_RET_VAL(iteration_);
|
||||||
|
MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_);
|
||||||
|
auto tcp_comm = std::dynamic_pointer_cast<ps::core::TcpCommunicator>(communicator_with_server_);
|
||||||
|
MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm);
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Receive SyncAfterRecover request from other server.";
|
||||||
|
std::string response = "success";
|
||||||
|
if (!tcp_comm->SendResponse(response.c_str(), response.size(), message)) {
|
||||||
|
MS_LOG(ERROR) << "Sending response of SyncAfterRecoverRequest failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!safemode_.load()) {
|
||||||
|
MS_LOG(INFO) << "Need to synchronize for other server's recovery";
|
||||||
|
SyncAfterRecover sync_after_recovery_req;
|
||||||
|
(void)sync_after_recovery_req.ParseFromArray(message->data(), SizeToInt(message->len()));
|
||||||
|
if (!iteration_->SyncAfterRecovery(sync_after_recovery_req.current_iter_num())) {
|
||||||
|
MS_LOG(ERROR) << "Sync after recovery failed.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace server
|
} // namespace server
|
||||||
} // namespace fl
|
} // namespace fl
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -24,19 +24,19 @@
|
||||||
#include "ps/core/communicator/tcp_communicator.h"
|
#include "ps/core/communicator/tcp_communicator.h"
|
||||||
#include "ps/core/communicator/task_executor.h"
|
#include "ps/core/communicator/task_executor.h"
|
||||||
#include "ps/core/file_configuration.h"
|
#include "ps/core/file_configuration.h"
|
||||||
#include "fl/server/common.h"
|
|
||||||
#include "fl/server/executor.h"
|
|
||||||
#include "fl/server/iteration.h"
|
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
#include "fl/armour/cipher/cipher_init.h"
|
#include "fl/armour/cipher/cipher_init.h"
|
||||||
#endif
|
#endif
|
||||||
|
#include "fl/server/common.h"
|
||||||
|
#include "fl/server/executor.h"
|
||||||
|
#include "fl/server/iteration.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace fl {
|
namespace fl {
|
||||||
namespace server {
|
namespace server {
|
||||||
// The sleeping time of the server thread before the networking is completed.
|
// The sleeping time of the server thread before the networking is completed.
|
||||||
constexpr uint32_t kServerSleepTimeForNetworking = 1000;
|
constexpr uint32_t kServerSleepTimeForNetworking = 1000;
|
||||||
|
constexpr uint64_t kDefaultReplayAttackTimeDiff = 60000;
|
||||||
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
|
// Class Server is the entrance of MindSpore's parameter server training mode and federated learning.
|
||||||
class Server {
|
class Server {
|
||||||
public:
|
public:
|
||||||
|
@ -51,6 +51,21 @@ class Server {
|
||||||
// According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
|
// According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be
|
||||||
// blocked until the server is finalized.
|
// blocked until the server is finalized.
|
||||||
// func_graph is the frontend graph which will be parsed in server's executor and aggregator.
|
// func_graph is the frontend graph which will be parsed in server's executor and aggregator.
|
||||||
|
|
||||||
|
// Each step of the server pipeline may have dependency on other steps, which includes:
|
||||||
|
// InitServerContext must be the first step to set contexts for later steps.
|
||||||
|
|
||||||
|
// Server Running relies on URL or Message Type Register:
|
||||||
|
// StartCommunicator---->InitIteration
|
||||||
|
|
||||||
|
// Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion:
|
||||||
|
// RegisterRoundKernel---->StartCommunicator
|
||||||
|
|
||||||
|
// Kernel Initialization relies on Executor Initialization:
|
||||||
|
// RegisterRoundKernel---->InitExecutor
|
||||||
|
|
||||||
|
// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization:
|
||||||
|
// InitCipher---->InitExecutor
|
||||||
void Run();
|
void Run();
|
||||||
|
|
||||||
void SwitchToSafeMode();
|
void SwitchToSafeMode();
|
||||||
|
@ -74,17 +89,25 @@ class Server {
|
||||||
communicators_with_worker_({}),
|
communicators_with_worker_({}),
|
||||||
iteration_(nullptr),
|
iteration_(nullptr),
|
||||||
safemode_(true),
|
safemode_(true),
|
||||||
|
server_recovery_(nullptr),
|
||||||
scheduler_ip_(""),
|
scheduler_ip_(""),
|
||||||
scheduler_port_(0),
|
scheduler_port_(0),
|
||||||
server_num_(0),
|
server_num_(0),
|
||||||
worker_num_(0),
|
worker_num_(0),
|
||||||
fl_server_port_(0),
|
fl_server_port_(0),
|
||||||
|
pki_verify_(false),
|
||||||
|
root_first_ca_path_(""),
|
||||||
|
root_second_ca_path_(""),
|
||||||
|
equip_crl_path_(""),
|
||||||
|
replay_attack_time_diff_(kDefaultReplayAttackTimeDiff),
|
||||||
cipher_initial_client_cnt_(0),
|
cipher_initial_client_cnt_(0),
|
||||||
cipher_exchange_keys_cnt_(0),
|
cipher_exchange_keys_cnt_(0),
|
||||||
cipher_get_keys_cnt_(0),
|
cipher_get_keys_cnt_(0),
|
||||||
cipher_share_secrets_cnt_(0),
|
cipher_share_secrets_cnt_(0),
|
||||||
cipher_get_secrets_cnt_(0),
|
cipher_get_secrets_cnt_(0),
|
||||||
cipher_get_clientlist_cnt_(0),
|
cipher_get_clientlist_cnt_(0),
|
||||||
|
cipher_push_list_sign_cnt_(0),
|
||||||
|
cipher_get_list_sign_cnt_(0),
|
||||||
cipher_reconstruct_secrets_up_cnt_(0),
|
cipher_reconstruct_secrets_up_cnt_(0),
|
||||||
cipher_reconstruct_secrets_down_cnt_(0),
|
cipher_reconstruct_secrets_down_cnt_(0),
|
||||||
cipher_time_window_(0) {}
|
cipher_time_window_(0) {}
|
||||||
|
@ -95,9 +118,6 @@ class Server {
|
||||||
// Load variables which are set by ps_context.
|
// Load variables which are set by ps_context.
|
||||||
void InitServerContext();
|
void InitServerContext();
|
||||||
|
|
||||||
// Try to recover server config from persistent storage.
|
|
||||||
void Recovery();
|
|
||||||
|
|
||||||
// Initialize the server cluster, server node and communicators.
|
// Initialize the server cluster, server node and communicators.
|
||||||
void InitCluster();
|
void InitCluster();
|
||||||
bool InitCommunicatorWithServer();
|
bool InitCommunicatorWithServer();
|
||||||
|
@ -130,6 +150,12 @@ class Server {
|
||||||
// The communicators should be started after all initializations are completed.
|
// The communicators should be started after all initializations are completed.
|
||||||
void StartCommunicator();
|
void StartCommunicator();
|
||||||
|
|
||||||
|
// Try to recover server config from persistent storage.
|
||||||
|
void Recover();
|
||||||
|
|
||||||
|
// load pki huks cbg root certificate and crl
|
||||||
|
void InitPkiCertificate();
|
||||||
|
|
||||||
// The barriers before scaling operations.
|
// The barriers before scaling operations.
|
||||||
void ProcessBeforeScalingOut();
|
void ProcessBeforeScalingOut();
|
||||||
void ProcessBeforeScalingIn();
|
void ProcessBeforeScalingIn();
|
||||||
|
@ -148,6 +174,9 @@ class Server {
|
||||||
// Query current instance information.
|
// Query current instance information.
|
||||||
void HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
void HandleQueryInstanceRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
|
// Synchronize after recovery is completed to ensure consistency.
|
||||||
|
void HandleSyncAfterRecoveryRequest(const std::shared_ptr<ps::core::MessageHandler> &message);
|
||||||
|
|
||||||
// The server node is initialized in Server.
|
// The server node is initialized in Server.
|
||||||
std::shared_ptr<ps::core::ServerNode> server_node_;
|
std::shared_ptr<ps::core::ServerNode> server_node_;
|
||||||
|
|
||||||
|
@ -179,7 +208,7 @@ class Server {
|
||||||
// communicators.
|
// communicators.
|
||||||
std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
|
std::vector<std::shared_ptr<ps::core::CommunicatorBase>> communicators_with_worker_;
|
||||||
|
|
||||||
// Mutex for scaling operations. We must wait server's initialization done before handle scaling events.
|
// Mutex for scaling operations.
|
||||||
std::mutex scaling_mtx_;
|
std::mutex scaling_mtx_;
|
||||||
|
|
||||||
// Iteration consists of multiple kinds of rounds.
|
// Iteration consists of multiple kinds of rounds.
|
||||||
|
@ -189,21 +218,33 @@ class Server {
|
||||||
// If true, the server is not available to workers and clients.
|
// If true, the server is not available to workers and clients.
|
||||||
std::atomic_bool safemode_;
|
std::atomic_bool safemode_;
|
||||||
|
|
||||||
|
// The recovery object for server.
|
||||||
|
std::shared_ptr<ServerRecovery> server_recovery_;
|
||||||
|
|
||||||
// Variables set by ps context.
|
// Variables set by ps context.
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
armour::CipherInit *cipher_init_{nullptr};
|
armour::CipherInit *cipher_init_;
|
||||||
#endif
|
#endif
|
||||||
std::string scheduler_ip_;
|
std::string scheduler_ip_;
|
||||||
uint16_t scheduler_port_;
|
uint16_t scheduler_port_;
|
||||||
uint32_t server_num_;
|
uint32_t server_num_;
|
||||||
uint32_t worker_num_;
|
uint32_t worker_num_;
|
||||||
uint16_t fl_server_port_;
|
uint16_t fl_server_port_;
|
||||||
|
bool pki_verify_;
|
||||||
|
|
||||||
|
std::string root_first_ca_path_;
|
||||||
|
std::string root_second_ca_path_;
|
||||||
|
std::string equip_crl_path_;
|
||||||
|
uint64_t replay_attack_time_diff_;
|
||||||
|
|
||||||
size_t cipher_initial_client_cnt_;
|
size_t cipher_initial_client_cnt_;
|
||||||
size_t cipher_exchange_keys_cnt_;
|
size_t cipher_exchange_keys_cnt_;
|
||||||
size_t cipher_get_keys_cnt_;
|
size_t cipher_get_keys_cnt_;
|
||||||
size_t cipher_share_secrets_cnt_;
|
size_t cipher_share_secrets_cnt_;
|
||||||
size_t cipher_get_secrets_cnt_;
|
size_t cipher_get_secrets_cnt_;
|
||||||
size_t cipher_get_clientlist_cnt_;
|
size_t cipher_get_clientlist_cnt_;
|
||||||
|
size_t cipher_push_list_sign_cnt_;
|
||||||
|
size_t cipher_get_list_sign_cnt_;
|
||||||
size_t cipher_reconstruct_secrets_up_cnt_;
|
size_t cipher_reconstruct_secrets_up_cnt_;
|
||||||
size_t cipher_reconstruct_secrets_down_cnt_;
|
size_t cipher_reconstruct_secrets_down_cnt_;
|
||||||
uint64_t cipher_time_window_;
|
uint64_t cipher_time_window_;
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "fl/server/server_recovery.h"
|
||||||
|
#include "fl/server/local_meta_store.h"
|
||||||
|
#include "debug/common.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace fl {
|
||||||
|
namespace server {
|
||||||
|
bool ServerRecovery::Initialize(const std::string &config_file) {
|
||||||
|
config_ = std::make_unique<ps::core::FileConfiguration>(config_file);
|
||||||
|
MS_EXCEPTION_IF_NULL(config_);
|
||||||
|
if (!config_->Initialize()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Initializing for server recovery failed. Config file path " << config_file
|
||||||
|
<< " may be invalid or not exist.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the server recovery file path.
|
||||||
|
if (!config_->Exists(kServerRecovery)) {
|
||||||
|
MS_LOG(WARNING) << "Server recovery config is not set. This node doesn't support recovery.";
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
std::string value = config_->Get(kServerRecovery, "");
|
||||||
|
nlohmann::json value_json;
|
||||||
|
try {
|
||||||
|
value_json = nlohmann::json::parse(value);
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
MS_LOG(EXCEPTION) << "The data is not in json format.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the storage type.
|
||||||
|
uint32_t storage_type = JsonGetKeyWithException<uint32_t>(value_json, ps::kStoreType);
|
||||||
|
if (std::to_string(storage_type) != ps::kFileStorage) {
|
||||||
|
MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse storage file path.
|
||||||
|
server_recovery_file_path_ = JsonGetKeyWithException<std::string>(value_json, ps::kStoreFilePath);
|
||||||
|
MS_LOG(INFO) << "Server recovery file path is " << server_recovery_file_path_;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ServerRecovery::Recover() {
|
||||||
|
server_recovery_file_.open(server_recovery_file_path_, std::ios::in);
|
||||||
|
if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) {
|
||||||
|
MS_LOG(WARNING) << "Can't open server recovery file " << server_recovery_file_path_;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
nlohmann::json server_recovery_json;
|
||||||
|
try {
|
||||||
|
server_recovery_json = nlohmann::json::parse(server_recovery_file_);
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
MS_LOG(EXCEPTION) << "The server recovery file is not in json format.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
uint64_t current_iter = JsonGetKeyWithException<uint64_t>(server_recovery_json, kCurrentIteration);
|
||||||
|
LocalMetaStore::GetInstance().set_curr_iter_num(current_iter);
|
||||||
|
MS_LOG(INFO) << "Recover from persistent storage: current iteration number is " << current_iter;
|
||||||
|
server_recovery_file_.close();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ServerRecovery::Save(uint64_t current_iter) {
|
||||||
|
server_recovery_file_.open(server_recovery_file_path_, std::ios::out | std::ios::ate);
|
||||||
|
if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) {
|
||||||
|
MS_LOG(WARNING) << "Can't save data to recovery file " << server_recovery_file_path_
|
||||||
|
<< ". This file path is invalid or does not exit.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
nlohmann::json server_metadata_json;
|
||||||
|
server_metadata_json[kCurrentIteration] = current_iter;
|
||||||
|
server_recovery_file_ << server_metadata_json;
|
||||||
|
server_recovery_file_.close();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lets a follower server tell the leader server which iteration number it holds after recovery.
// The leader itself (rank_id == kLeaderServerRank) sends nothing and the call returns true right away.
// Returns false when the communicator is null or the request can't be sent.
bool ServerRecovery::SyncAfterRecovery(const std::shared_ptr<ps::core::TcpCommunicator> &communicator,
                                       uint32_t rank_id) {
  // Only follower servers have to notify the leader.
  if (rank_id == kLeaderServerRank) {
    return true;
  }

  MS_ERROR_IF_NULL_W_RET_VAL(communicator, false);
  // The notification carries the iteration number currently held in LocalMetaStore.
  SyncAfterRecover recovery_notice;
  recovery_notice.set_current_iter_num(LocalMetaStore::GetInstance().curr_iter_num());
  if (!communicator->SendPbRequest(recovery_notice, kLeaderServerRank, ps::core::TcpUserCommand::kSyncAfterRecover)) {
    MS_LOG(ERROR) << "Sending sync after recovery message to leader server failed.";
    return false;
  }
  return true;
}
|
||||||
|
} // namespace server
|
||||||
|
} // namespace fl
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,65 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_RECOVERY_H_
|
||||||
|
#define MINDSPORE_CCSRC_FL_SERVER_SERVER_RECOVERY_H_
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "ps/core/recovery_base.h"
|
||||||
|
#include "ps/core/file_configuration.h"
|
||||||
|
#include "ps/core/communicator/tcp_communicator.h"
|
||||||
|
#include "ps/ps_context.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace fl {
|
||||||
|
namespace server {
|
||||||
|
constexpr auto kServerRecovery = "server_recovery";
|
||||||
|
|
||||||
|
// The class helps server node to do recovery operation.
|
||||||
|
// Different from the recovery process in ps/core/node_recovery.*, this class focus on recovery of the server data. For
|
||||||
|
// example, current iteration number, learning rate, etc.
|
||||||
|
class ServerRecovery : public ps::core::RecoveryBase {
|
||||||
|
public:
|
||||||
|
ServerRecovery() : config_(nullptr), server_recovery_file_path_("") {}
|
||||||
|
~ServerRecovery() override = default;
|
||||||
|
|
||||||
|
bool Initialize(const std::string &config_file) override;
|
||||||
|
bool Recover() override;
|
||||||
|
|
||||||
|
// Save server's metadata to persistent storage.
|
||||||
|
bool Save(uint64_t current_iter);
|
||||||
|
|
||||||
|
// If this server recovers, need to notify cluster to reach consistency.
|
||||||
|
bool SyncAfterRecovery(const std::shared_ptr<ps::core::TcpCommunicator> &communicator, uint32_t rank_id);
|
||||||
|
|
||||||
|
private:
|
||||||
|
// This is the main config file set by ps context.
|
||||||
|
std::unique_ptr<ps::core::FileConfiguration> config_;
|
||||||
|
|
||||||
|
// The server recovery file path.
|
||||||
|
std::string server_recovery_file_path_;
|
||||||
|
|
||||||
|
// The server recovery file object.
|
||||||
|
std::fstream server_recovery_file_;
|
||||||
|
};
|
||||||
|
} // namespace server
|
||||||
|
} // namespace fl
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_FL_SERVER_SERVER_RECOVERY_H_
|
|
@ -220,9 +220,9 @@ void FLWorker::InitializeFollowerScaler() {
|
||||||
std::bind(&FLWorker::ProcessAfterScalingOut, this));
|
std::bind(&FLWorker::ProcessAfterScalingOut, this));
|
||||||
worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
|
worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline",
|
||||||
std::bind(&FLWorker::ProcessAfterScalingIn, this));
|
std::bind(&FLWorker::ProcessAfterScalingIn, this));
|
||||||
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationRunning),
|
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::UserDefineEvent::kIterationRunning),
|
||||||
std::bind(&FLWorker::HandleIterationRunningEvent, this));
|
std::bind(&FLWorker::HandleIterationRunningEvent, this));
|
||||||
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::CustomEvent::kIterationCompleted),
|
worker_node_->RegisterCustomEventCallback(static_cast<uint32_t>(ps::UserDefineEvent::kIterationCompleted),
|
||||||
std::bind(&FLWorker::HandleIterationCompletedEvent, this));
|
std::bind(&FLWorker::HandleIterationCompletedEvent, this));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ constexpr uint32_t kWorkerSleepTimeForNetworking = 1000;
|
||||||
// The time duration between retrying when server is in safemode.
|
// The time duration between retrying when server is in safemode.
|
||||||
constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;
|
constexpr uint32_t kWorkerRetryDurationForSafeMode = 500;
|
||||||
|
|
||||||
// The rank of the leader server.
|
// The leader server rank.
|
||||||
constexpr uint32_t kLeaderServerRank = 0;
|
constexpr uint32_t kLeaderServerRank = 0;
|
||||||
|
|
||||||
// The timeout for worker sending message to server in case of network jitter.
|
// The timeout for worker sending message to server in case of network jitter.
|
||||||
|
|
|
@ -879,6 +879,10 @@ bool StartServerAction(const ResourcePtr &res) {
|
||||||
std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold);
|
std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold);
|
||||||
size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)),
|
size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)),
|
||||||
reconstruct_secrets_threshold);
|
reconstruct_secrets_threshold);
|
||||||
|
size_t push_list_sign_threshold = std::max(
|
||||||
|
static_cast<size_t>(std::ceil(client_list_threshold * share_secrets_ratio)), reconstruct_secrets_threshold);
|
||||||
|
size_t get_list_sign_threshold = std::max(
|
||||||
|
static_cast<size_t>(std::ceil(push_list_sign_threshold * share_secrets_ratio)), reconstruct_secrets_threshold);
|
||||||
#ifdef ENABLE_ARMOUR
|
#ifdef ENABLE_ARMOUR
|
||||||
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
|
||||||
if (encrypt_type == ps::kPWEncryptType) {
|
if (encrypt_type == ps::kPWEncryptType) {
|
||||||
|
@ -889,6 +893,10 @@ bool StartServerAction(const ResourcePtr &res) {
|
||||||
rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold});
|
rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold});
|
||||||
rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
|
rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
|
||||||
rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
|
rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
|
||||||
|
if (ps::PSContext::instance()->pki_verify()) {
|
||||||
|
rounds_config.push_back({"pushListSign", true, cipher_time_window, true, push_list_sign_threshold});
|
||||||
|
rounds_config.push_back({"getListSign", true, cipher_time_window, true, get_list_sign_threshold});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (encrypt_type == ps::kStablePWEncryptType) {
|
if (encrypt_type == ps::kStablePWEncryptType) {
|
||||||
MS_LOG(INFO) << "Add stable secure aggregation rounds.";
|
MS_LOG(INFO) << "Add stable secure aggregation rounds.";
|
||||||
|
@ -897,8 +905,9 @@ bool StartServerAction(const ResourcePtr &res) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
fl::server::CipherConfig cipher_config = {
|
fl::server::CipherConfig cipher_config = {
|
||||||
share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,
|
share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,
|
||||||
share_secrets_threshold, get_secrets_threshold, client_list_threshold, reconstruct_secrets_threshold};
|
share_secrets_threshold, get_secrets_threshold, client_list_threshold, push_list_sign_threshold,
|
||||||
|
get_list_sign_threshold, reconstruct_secrets_threshold};
|
||||||
|
|
||||||
size_t executor_threshold = 0;
|
size_t executor_threshold = 0;
|
||||||
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
|
||||||
|
|
|
@ -403,6 +403,7 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
"Set threshold count ratio for share secrets round.")
|
"Set threshold count ratio for share secrets round.")
|
||||||
.def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.")
|
.def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.")
|
||||||
.def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.")
|
.def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.")
|
||||||
|
.def("cipher_time_window", &PSContext::cipher_time_window, "Get time window for cipher rounds.")
|
||||||
.def("set_reconstruct_secrets_threshold", &PSContext::set_reconstruct_secrets_threshold,
|
.def("set_reconstruct_secrets_threshold", &PSContext::set_reconstruct_secrets_threshold,
|
||||||
"Set threshold count for reconstruct secrets round.")
|
"Set threshold count for reconstruct secrets round.")
|
||||||
.def("reconstruct_secrets_threshold", &PSContext::reconstruct_secrets_threshold,
|
.def("reconstruct_secrets_threshold", &PSContext::reconstruct_secrets_threshold,
|
||||||
|
@ -416,16 +417,38 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
|
.def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
|
||||||
.def("client_batch_size", &PSContext::client_batch_size, "Get federated learning client batch size.")
|
.def("client_batch_size", &PSContext::client_batch_size, "Get federated learning client batch size.")
|
||||||
.def("set_client_learning_rate", &PSContext::set_client_learning_rate,
|
.def("set_client_learning_rate", &PSContext::set_client_learning_rate,
|
||||||
"Set worker's standalone training step number before communicating with server.")
|
"Set federated learning client learning rate.")
|
||||||
.def("client_learning_rate", &PSContext::client_learning_rate,
|
.def("client_learning_rate", &PSContext::client_learning_rate,
|
||||||
"Get worker's standalone training step number before communicating with server.")
|
"Get worker's standalone training step number before communicating with server.")
|
||||||
.def("set_worker_step_num_per_iteration", &PSContext::set_worker_step_num_per_iteration,
|
.def("set_worker_step_num_per_iteration", &PSContext::set_worker_step_num_per_iteration,
|
||||||
"Set federated learning client learning rate.")
|
"Set worker's standalone training step number before communicating with server..")
|
||||||
.def("worker_step_num_per_iteration", &PSContext::worker_step_num_per_iteration,
|
.def("worker_step_num_per_iteration", &PSContext::worker_step_num_per_iteration,
|
||||||
"Get federated learning client learning rate.")
|
"Get federated learning client learning rate.")
|
||||||
|
.def("set_secure_aggregation", &PSContext::set_secure_aggregation,
|
||||||
|
"Set federated learning client using secure aggregation.")
|
||||||
|
.def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.")
|
||||||
|
.def("dp_eps", &PSContext::dp_eps, "Get dp epsilon for federated learning secure aggregation.")
|
||||||
|
.def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.")
|
||||||
|
.def("dp_delta", &PSContext::dp_delta, "Get dp delta for federated learning secure aggregation.")
|
||||||
|
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
|
||||||
|
"Set dp norm clip for federated learning secure aggregation.")
|
||||||
|
.def("dp_norm_clip", &PSContext::dp_norm_clip, "Get dp norm clip for federated learning secure aggregation.")
|
||||||
|
.def("set_encrypt_type", &PSContext::set_encrypt_type,
|
||||||
|
"Set encrypt type for federated learning secure aggregation.")
|
||||||
|
.def("encrypt_type", &PSContext::encrypt_type, "Get encrypt type for federated learning secure aggregation.")
|
||||||
|
.def("set_root_first_ca_path", &PSContext::set_root_first_ca_path, "Set root first ca path.")
|
||||||
|
.def("root_first_ca_path", &PSContext::root_first_ca_path, "Get root first ca path.")
|
||||||
|
.def("set_root_second_ca_path", &PSContext::set_root_second_ca_path, "Set root second ca path.")
|
||||||
|
.def("root_second_ca_path", &PSContext::root_second_ca_path, "Get root second ca path.")
|
||||||
|
.def("set_pki_verify", &PSContext::set_pki_verify, "Set pki verify.")
|
||||||
|
.def("pki_verify", &PSContext::pki_verify, "Get pki verify.")
|
||||||
.def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port,
|
.def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port,
|
||||||
"Set scheduler manage port used to scale out/in.")
|
"Set scheduler manage port used to scale out/in.")
|
||||||
.def("scheduler_manage_port", &PSContext::scheduler_manage_port, "Get scheduler manage port used to scale out/in.")
|
.def("scheduler_manage_port", &PSContext::scheduler_manage_port, "Get scheduler manage port used to scale out/in.")
|
||||||
|
.def("set_equip_crl_path", &PSContext::set_equip_crl_path, "Set root second crl path.")
|
||||||
|
.def("set_replay_attack_time_diff", &PSContext::set_replay_attack_time_diff, "Set replay attack time diff.")
|
||||||
|
.def("equip_crl_path", &PSContext::equip_crl_path, "Get root second crl path.")
|
||||||
|
.def("replay_attack_time_diff", &PSContext::replay_attack_time_diff, "Get replay attack time diff.")
|
||||||
.def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.")
|
.def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.")
|
||||||
.def("enable_ssl", &PSContext::enable_ssl, "Get PS SSL mode enabled or disabled.")
|
.def("enable_ssl", &PSContext::enable_ssl, "Get PS SSL mode enabled or disabled.")
|
||||||
.def("set_client_password", &PSContext::set_client_password, "Set the client password to decode the p12 file.")
|
.def("set_client_password", &PSContext::set_client_password, "Set the client password to decode the p12 file.")
|
||||||
|
@ -441,7 +464,9 @@ PYBIND11_MODULE(_c_expression, m) {
|
||||||
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
|
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
|
||||||
"Set dp norm clip for federated learning secure aggregation.")
|
"Set dp norm clip for federated learning secure aggregation.")
|
||||||
.def("set_encrypt_type", &PSContext::set_encrypt_type,
|
.def("set_encrypt_type", &PSContext::set_encrypt_type,
|
||||||
"Set encrypt type for federated learning secure aggregation.");
|
"Set encrypt type for federated learning secure aggregation.")
|
||||||
|
.def("set_http_url_prefix", &PSContext::set_http_url_prefix, "Set http url prefix for http communication.")
|
||||||
|
.def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication.");
|
||||||
|
|
||||||
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
|
(void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
|
||||||
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
|
(void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
|
||||||
|
|
|
@ -76,28 +76,15 @@ constexpr int64_t kPullCmd = 51;
|
||||||
constexpr size_t kInvalidKey = UINT64_MAX;
|
constexpr size_t kInvalidKey = UINT64_MAX;
|
||||||
constexpr int64_t kInvalidID = -1;
|
constexpr int64_t kInvalidID = -1;
|
||||||
|
|
||||||
constexpr int64_t kGradIndex = 0;
|
|
||||||
constexpr int64_t kIndiceIndex = 1;
|
|
||||||
constexpr int64_t kFirstDimSize = 2;
|
|
||||||
constexpr int64_t kOutDimSize = 3;
|
|
||||||
|
|
||||||
constexpr int64_t kBase = 10;
|
|
||||||
constexpr float kStdDev = 0.01;
|
|
||||||
|
|
||||||
constexpr int64_t kSparseLazyAdamIndex = 2;
|
|
||||||
constexpr int64_t kSparseFtrlIndex = 3;
|
|
||||||
constexpr int64_t kSparseGradIndex = 6;
|
|
||||||
constexpr int64_t kSparseIndiceIndex = 7;
|
|
||||||
|
|
||||||
constexpr int64_t kHeartbeatTimes = 2;
|
|
||||||
constexpr int64_t kGradValue = -100;
|
|
||||||
|
|
||||||
constexpr uint32_t kMaxMessageSize = static_cast<uint32_t>(100 * (uint32_t(1) << 20));
|
constexpr uint32_t kMaxMessageSize = static_cast<uint32_t>(100 * (uint32_t(1) << 20));
|
||||||
constexpr char kServerNum[] = "server_num";
|
constexpr char kServerNum[] = "server_num";
|
||||||
constexpr char kWorkerNum[] = "worker_num";
|
constexpr char kWorkerNum[] = "worker_num";
|
||||||
constexpr char kNodesIds[] = "node_ids";
|
constexpr char kNodesIds[] = "node_ids";
|
||||||
constexpr char kNodeId[] = "node_id";
|
constexpr char kNodeId[] = "node_id";
|
||||||
|
|
||||||
|
constexpr char kSuccessCode[] = "0";
|
||||||
|
constexpr char kErrorCode[] = "1";
|
||||||
|
|
||||||
constexpr int64_t kSubmitTaskIntervalInMs = 1;
|
constexpr int64_t kSubmitTaskIntervalInMs = 1;
|
||||||
constexpr int64_t kMaxTaskNum = 10240;
|
constexpr int64_t kMaxTaskNum = 10240;
|
||||||
constexpr int64_t kSubmitTimeOutInMs = 30000;
|
constexpr int64_t kSubmitTimeOutInMs = 30000;
|
||||||
|
@ -105,7 +92,13 @@ constexpr int64_t kRetryCount = 60;
|
||||||
constexpr int64_t kRetryIntervalInMs = 10;
|
constexpr int64_t kRetryIntervalInMs = 10;
|
||||||
|
|
||||||
constexpr int64_t kThreadNum = 32;
|
constexpr int64_t kThreadNum = 32;
|
||||||
|
constexpr int64_t kGradIndex = 0;
|
||||||
|
constexpr int64_t kIndiceIndex = 1;
|
||||||
|
constexpr int64_t kFirstDimSize = 2;
|
||||||
|
constexpr int64_t kOutDimSize = 3;
|
||||||
|
|
||||||
|
constexpr int64_t kBase = 10;
|
||||||
|
constexpr float kStdDev = 0.01;
|
||||||
// The timeout period for the scale in node to send the finish message to scheduler.
|
// The timeout period for the scale in node to send the finish message to scheduler.
|
||||||
constexpr uint32_t kScaleInTimeoutInSenconds = 30;
|
constexpr uint32_t kScaleInTimeoutInSenconds = 30;
|
||||||
// The number of retries to determine whether all nodes are successfully registered.
|
// The number of retries to determine whether all nodes are successfully registered.
|
||||||
|
@ -113,10 +106,26 @@ constexpr uint32_t kCheckRegisteredRetryCount = 30;
|
||||||
// The timeout interval for judging whether all nodes are successfully registered.
|
// The timeout interval for judging whether all nodes are successfully registered.
|
||||||
constexpr uint32_t kCheckRegisteredIntervalInMs = 1000;
|
constexpr uint32_t kCheckRegisteredIntervalInMs = 1000;
|
||||||
|
|
||||||
|
// The barrier function which should be called before doing scaling out/in operations.
|
||||||
|
// It's easy for us to scale out/in nodes after one iteration is completed and keep consistent.
|
||||||
|
using BarrierBeforeScaleOut = std::function<void(void)>;
|
||||||
|
using BarrierBeforeScaleIn = std::function<void(void)>;
|
||||||
|
|
||||||
|
constexpr int64_t kSparseLazyAdamIndex = 2;
|
||||||
|
constexpr int64_t kSparseFtrlIndex = 3;
|
||||||
|
constexpr int64_t kSparseGradIndex = 6;
|
||||||
|
constexpr int64_t kSparseIndiceIndex = 7;
|
||||||
|
|
||||||
|
constexpr int64_t kHeartbeatTimes = 2;
|
||||||
|
constexpr int64_t kGradValue = -100;
|
||||||
|
// Whether to support recovery.
|
||||||
|
constexpr char kIsRecovery[] = "is_recovery";
|
||||||
// The type of persistent storage, currently only supports file storage.
|
// The type of persistent storage, currently only supports file storage.
|
||||||
constexpr char kStoreType[] = "storage_type";
|
constexpr char kStoreType[] = "storage_type";
|
||||||
// The file used to store metadata.
|
// The file used to store metadata.
|
||||||
constexpr char kStoreFilePath[] = "storage_file_path";
|
constexpr char kStoreFilePath[] = "storage_file_path";
|
||||||
|
// The file used to store scheduler metadata.
|
||||||
|
constexpr char kSchedulerStoreFilePath[] = "scheduler_storage_file_path";
|
||||||
// 1 indicates that the persistent storage type is file.
|
// 1 indicates that the persistent storage type is file.
|
||||||
constexpr char kFileStorage[] = "1";
|
constexpr char kFileStorage[] = "1";
|
||||||
// The recovery key of json_config.
|
// The recovery key of json_config.
|
||||||
|
@ -125,6 +134,10 @@ constexpr char kRecoveryWorkerNum[] = "worker_num";
|
||||||
constexpr char kRecoveryServerNum[] = "server_num";
|
constexpr char kRecoveryServerNum[] = "server_num";
|
||||||
constexpr char kRecoverySchedulerIp[] = "scheduler_ip";
|
constexpr char kRecoverySchedulerIp[] = "scheduler_ip";
|
||||||
constexpr char kRecoverySchedulerPort[] = "scheduler_port";
|
constexpr char kRecoverySchedulerPort[] = "scheduler_port";
|
||||||
|
constexpr char kRecoveryTotalNodeNum[] = "total_node_num";
|
||||||
|
constexpr char kRecoveryNextWorkerRankId[] = "next_worker_rank_id";
|
||||||
|
constexpr char kRecoveryNextServerRankId[] = "next_server_rank_id";
|
||||||
|
constexpr char kRecoveryRegisteredNodesInfos[] = "node_ids";
|
||||||
|
|
||||||
constexpr char kServerCertPath[] = "server_cert_path";
|
constexpr char kServerCertPath[] = "server_cert_path";
|
||||||
constexpr char kServerPassword[] = "server_password";
|
constexpr char kServerPassword[] = "server_password";
|
||||||
|
@ -132,7 +145,6 @@ constexpr char kCrlPath[] = "crl_path";
|
||||||
constexpr char kClientCertPath[] = "client_cert_path";
|
constexpr char kClientCertPath[] = "client_cert_path";
|
||||||
constexpr char kClientPassword[] = "client_password";
|
constexpr char kClientPassword[] = "client_password";
|
||||||
constexpr char kCaCertPath[] = "ca_cert_path";
|
constexpr char kCaCertPath[] = "ca_cert_path";
|
||||||
|
|
||||||
constexpr char kCipherList[] = "cipher_list";
|
constexpr char kCipherList[] = "cipher_list";
|
||||||
constexpr char kCertCheckInterval[] = "cert_check_interval_in_hour";
|
constexpr char kCertCheckInterval[] = "cert_check_interval_in_hour";
|
||||||
// 7 * 24
|
// 7 * 24
|
||||||
|
@ -154,6 +166,7 @@ constexpr int64_t kMaxWarningTime = 180;
|
||||||
|
|
||||||
constexpr int64_t kLength = 100;
|
constexpr int64_t kLength = 100;
|
||||||
constexpr int64_t kMaxPort = 65535;
|
constexpr int64_t kMaxPort = 65535;
|
||||||
|
constexpr int64_t kSecurityLevel = 3;
|
||||||
|
|
||||||
constexpr char kTcpCommunicator[] = "TCP";
|
constexpr char kTcpCommunicator[] = "TCP";
|
||||||
constexpr char kHttpCommunicator[] = "HTTP";
|
constexpr char kHttpCommunicator[] = "HTTP";
|
||||||
|
@ -162,35 +175,14 @@ constexpr char kServerCert[] = "server.p12";
|
||||||
constexpr char kClientCert[] = "client.p12";
|
constexpr char kClientCert[] = "client.p12";
|
||||||
constexpr char kCaCert[] = "ca.crt";
|
constexpr char kCaCert[] = "ca.crt";
|
||||||
constexpr char kColon = ':';
|
constexpr char kColon = ':';
|
||||||
const std::map<std::string, size_t> kCiphers = {{"ECDHE-RSA-AES128-GCM-SHA256", 0},
|
const std::map<std::string, size_t> kCiphers = {
|
||||||
{"ECDHE-ECDSA-AES128-GCM-SHA256", 1},
|
{"ECDHE-RSA-AES128-GCM-SHA256", 0}, {"ECDHE-ECDSA-AES128-GCM-SHA256", 1}, {"ECDHE-RSA-AES256-GCM-SHA384", 2},
|
||||||
{"ECDHE-RSA-AES256-GCM-SHA384", 2},
|
{"ECDHE-ECDSA-AES256-GCM-SHA384", 3}, {"DHE-RSA-AES128-GCM-SHA256", 4}, {"DHE-DSS-AES128-GCM-SHA256", 5},
|
||||||
{"ECDHE-ECDSA-AES256-GCM-SHA384", 3},
|
{"DHE-RSA-AES256-GCM-SHA384", 6}, {"DHE-DSS-AES256-GCM-SHA384", 7}, {"DHE-PSK-AES128-GCM-SHA256", 8},
|
||||||
{"DHE-RSA-AES128-GCM-SHA256", 4},
|
{"DHE-PSK-AES256-GCM-SHA384", 9}, {"DHE-PSK-CHACHA20-POLY1305", 10}, {"ECDHE-RSA-CHACHA20-POLY1305", 11},
|
||||||
{"DHE-DSS-AES128-GCM-SHA256", 5},
|
{"ECDHE-PSK-CHACHA20-POLY1305", 12}, {"DHE-RSA-AES128-CCM", 13}, {"DHE-RSA-AES256-CCM", 14},
|
||||||
{"ECDHE-RSA-AES128-SHA256", 6},
|
{"DHE-RSA-CHACHA20-POLY1305", 15}, {"DHE-PSK-AES128-CCM", 16}, {"DHE-PSK-AES256-CCM", 17},
|
||||||
{"ECDHE-ECDSA-AES128-SHA256", 7},
|
{"ECDHE-ECDSA-AES128-CCM", 18}, {"ECDHE-ECDSA-AES256-CCM", 19}, {"ECDHE-ECDSA-CHACHA20-POLY1305", 20}};
|
||||||
{"ECDHE-RSA-AES128-SHA", 8},
|
|
||||||
{"ECDHE-ECDSA-AES128-SHA", 9},
|
|
||||||
{"ECDHE-RSA-AES256-SHA384", 10},
|
|
||||||
{"ECDHE-ECDSA-AES256-SHA384", 11},
|
|
||||||
{"ECDHE-RSA-AES256-SHA", 12},
|
|
||||||
{"ECDHE-ECDSA-AES256-SHA", 13},
|
|
||||||
{"DHE-RSA-AES128-SHA256", 14},
|
|
||||||
{"DHE-RSA-AES128-SHA", 15},
|
|
||||||
{"DHE-DSS-AES128-SHA256", 16},
|
|
||||||
{"DHE-RSA-AES256-SHA256", 17},
|
|
||||||
{"DHE-DSS-AES256-SHA", 18},
|
|
||||||
{"DHE-RSA-AES256-SHA", 19},
|
|
||||||
{"!aNULL", 20},
|
|
||||||
{"!eNULL", 21},
|
|
||||||
{"!EXPORT", 22},
|
|
||||||
{"!DES", 23},
|
|
||||||
{"!RC4", 24},
|
|
||||||
{"!3DES", 25},
|
|
||||||
{"!MD5", 26},
|
|
||||||
{"!PSK", 27},
|
|
||||||
{"kEDH+AESGCM", 28}};
|
|
||||||
|
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
using DataPtr = std::shared_ptr<unsigned char>;
|
using DataPtr = std::shared_ptr<unsigned char>;
|
||||||
|
@ -262,7 +254,7 @@ using HandlerAfterScaleIn = std::function<void(void)>;
|
||||||
constexpr char kClusterSafeMode[] = "The cluster is in safemode.";
|
constexpr char kClusterSafeMode[] = "The cluster is in safemode.";
|
||||||
constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished.";
|
constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished.";
|
||||||
|
|
||||||
enum class CustomEvent { kIterationRunning = 0, kIterationCompleted };
|
enum class UserDefineEvent { kIterationRunning = 0, kIterationCompleted, kNodeTimeout };
|
||||||
|
|
||||||
#define EXC_IF_VEC_IDX_OOB(vec, idx) \
|
#define EXC_IF_VEC_IDX_OOB(vec, idx) \
|
||||||
{ \
|
{ \
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -50,9 +50,18 @@ void AbstractNode::ProcessRegisterResp(const std::shared_ptr<MessageMeta> &meta,
|
||||||
MS_EXCEPTION_IF_NULL(data);
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
RegisterRespMessage register_resp_message;
|
RegisterRespMessage register_resp_message;
|
||||||
CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size)));
|
CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size)));
|
||||||
|
MS_LOG(INFO) << "The node id get from scheduler is:" << register_resp_message.node_id()
|
||||||
|
<< ", rank_id is:" << register_resp_message.rank_id();
|
||||||
|
|
||||||
if (register_resp_message.node_id() != node_info_.node_id_) {
|
if (register_resp_message.node_id() != node_info_.node_id_) {
|
||||||
MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id()
|
MS_LOG(ERROR) << "The node id received:" << register_resp_message.node_id()
|
||||||
<< " is not match the current node id:" << node_info_.node_id_;
|
<< " is not match the current node id:" << node_info_.node_id_;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
node_info_.rank_id_ = register_resp_message.rank_id();
|
||||||
|
if (node_info_.rank_id_ == UINT32_MAX) {
|
||||||
|
MS_LOG(ERROR) << "The rank id received:" << register_resp_message.rank_id();
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Receive the Register message, indicating that the scheduler is alive, so update the time point at which the
|
// Receive the Register message, indicating that the scheduler is alive, so update the time point at which the
|
||||||
|
@ -70,7 +79,7 @@ bool AbstractNode::Broadcast(const NodeRole &node_role, const DataPtr &message,
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t broadcast_size = 0;
|
uint32_t broadcast_size = 0;
|
||||||
std::for_each(nodes_address_.begin(), nodes_address_.end(), [&broadcast_size, &node_role](const auto &addr) {
|
(void)std::for_each(nodes_address_.begin(), nodes_address_.end(), [&broadcast_size, &node_role](const auto &addr) {
|
||||||
if (addr.first.first == node_role) {
|
if (addr.first.first == node_role) {
|
||||||
++broadcast_size;
|
++broadcast_size;
|
||||||
}
|
}
|
||||||
|
@ -160,19 +169,24 @@ void AbstractNode::BroadcastEvent(const uint32_t &event) {
|
||||||
MS_EXCEPTION_IF_NULL(message_meta);
|
MS_EXCEPTION_IF_NULL(message_meta);
|
||||||
message_meta->set_cmd(NodeCommand::SEND_EVENT);
|
message_meta->set_cmd(NodeCommand::SEND_EVENT);
|
||||||
|
|
||||||
EventMessage event_message;
|
EventRespMessage event_resp_message;
|
||||||
event_message.set_event(event);
|
event_resp_message.set_event(event);
|
||||||
event_message.set_node_id(node_info_.node_id_);
|
|
||||||
|
|
||||||
if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF, event_message.SerializeAsString().data(),
|
for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) {
|
||||||
event_message.ByteSizeLong())) {
|
const uint32_t rank_id = (*it).first.second;
|
||||||
MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
const NodeRole role = (*it).first.first;
|
||||||
<< " the node id:" << node_info_.node_id_ << " send event timeout!";
|
auto client = GetOrCreateTcpClient(rank_id, role);
|
||||||
return;
|
if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(),
|
||||||
|
event_resp_message.ByteSizeLong())) {
|
||||||
|
MS_LOG(ERROR) << "send event to node role:" << CommUtil::NodeRoleToString(role) << ", rank id:" << rank_id
|
||||||
|
<< " timeout!";
|
||||||
|
} else {
|
||||||
|
MS_LOG(INFO) << "send event to node role:" << CommUtil::NodeRoleToString(role) << ", rank id:" << rank_id
|
||||||
|
<< " successful!";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||||
<< " the node id:" << node_info_.node_id_ << "is send event to scheduler!";
|
<< " the node id:" << node_info_.node_id_ << "is send event to server/worker!";
|
||||||
}
|
}
|
||||||
|
|
||||||
void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) {
|
void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) {
|
||||||
|
@ -185,10 +199,6 @@ void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const Even
|
||||||
|
|
||||||
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
|
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len,
|
||||||
int command, const uint32_t &timeout) {
|
int command, const uint32_t &timeout) {
|
||||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
|
||||||
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
MS_EXCEPTION_IF_NULL(data);
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
||||||
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
|
MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_
|
||||||
|
@ -209,11 +219,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, cons
|
||||||
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||||
const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command,
|
const std::vector<DataPtr> &data, const std::vector<size_t> &lens, int command,
|
||||||
const uint32_t &timeout) {
|
const uint32_t &timeout) {
|
||||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
|
||||||
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t request_id = AddMessageTrack(data.size());
|
uint64_t request_id = AddMessageTrack(data.size());
|
||||||
|
|
||||||
if (rank_ids.size() != data.size() || rank_ids.size() != lens.size()) {
|
if (rank_ids.size() != data.size() || rank_ids.size() != lens.size()) {
|
||||||
|
@ -248,10 +253,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &
|
||||||
|
|
||||||
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
|
bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len,
|
||||||
int command, VectorPtr *output, const uint32_t &timeout) {
|
int command, VectorPtr *output, const uint32_t &timeout) {
|
||||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
|
||||||
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
MS_EXCEPTION_IF_NULL(message);
|
MS_EXCEPTION_IF_NULL(message);
|
||||||
MS_EXCEPTION_IF_NULL(output);
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) {
|
||||||
|
@ -289,10 +290,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, cons
|
||||||
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
bool AbstractNode::Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids,
|
||||||
const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command,
|
const std::vector<DataPtr> &data, const std::vector<size_t> &data_lens, int command,
|
||||||
std::vector<VectorPtr> *output, const uint32_t &timeout) {
|
std::vector<VectorPtr> *output, const uint32_t &timeout) {
|
||||||
if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) {
|
|
||||||
MS_LOG(DEBUG) << "The node is timeout, can not send message.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
MS_EXCEPTION_IF_NULL(output);
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
uint64_t request_id = AddMessageTrack(data.size());
|
uint64_t request_id = AddMessageTrack(data.size());
|
||||||
|
|
||||||
|
@ -493,10 +490,6 @@ std::shared_ptr<CommunicatorBase> AbstractNode::GetOrCreateTcpComm(const std::st
|
||||||
if (!communicators_.count(kTcpCommunicator)) {
|
if (!communicators_.count(kTcpCommunicator)) {
|
||||||
MS_LOG(INFO) << "Create Tcp communicator.";
|
MS_LOG(INFO) << "Create Tcp communicator.";
|
||||||
auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this);
|
auto tcp_comm = std::make_shared<TcpCommunicator>(task_executor, this);
|
||||||
PSContext::instance()->cluster_config().scheduler_host = scheduler_ip;
|
|
||||||
PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port);
|
|
||||||
PSContext::instance()->cluster_config().initial_worker_num = worker_num;
|
|
||||||
PSContext::instance()->cluster_config().initial_server_num = server_num;
|
|
||||||
MS_EXCEPTION_IF_NULL(tcp_comm);
|
MS_EXCEPTION_IF_NULL(tcp_comm);
|
||||||
PSContext::instance()->cluster_config().scheduler_host = scheduler_ip;
|
PSContext::instance()->cluster_config().scheduler_host = scheduler_ip;
|
||||||
PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port);
|
PSContext::instance()->cluster_config().scheduler_port = static_cast<uint16_t>(scheduler_port);
|
||||||
|
@ -521,13 +514,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr<TcpClient> &client)
|
||||||
MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||||
<< ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
|
<< ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!";
|
||||||
if (CheckSchedulerTimeout()) {
|
if (CheckSchedulerTimeout()) {
|
||||||
MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
MS_LOG(WARNING) << "Scheduler is Timeout, please recovery.";
|
||||||
<< ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!";
|
|
||||||
is_finish_ = true;
|
|
||||||
wait_finish_cond_.notify_all();
|
|
||||||
if (!is_already_stopped_) {
|
|
||||||
OnEventCallback(ClusterEvent::SCHEDULER_TIMEOUT);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
UpdateSchedulerTime();
|
UpdateSchedulerTime();
|
||||||
|
@ -549,6 +536,7 @@ bool AbstractNode::Heartbeat(const std::shared_ptr<TcpClient> &client) {
|
||||||
HeartbeatMessage heartbeat_message;
|
HeartbeatMessage heartbeat_message;
|
||||||
heartbeat_message.set_node_id(node_info_.node_id_);
|
heartbeat_message.set_node_id(node_info_.node_id_);
|
||||||
|
|
||||||
|
MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " Send heartbeat!";
|
||||||
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
|
if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(),
|
||||||
heartbeat_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
|
heartbeat_message.ByteSizeLong(), kCommTimeoutInSeconds)) {
|
||||||
MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
|
MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!";
|
||||||
|
@ -580,20 +568,31 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
|
||||||
HeartbeatRespMessage heartbeat_resp_message;
|
HeartbeatRespMessage heartbeat_resp_message;
|
||||||
CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));
|
CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size)));
|
||||||
|
|
||||||
|
if (heartbeat_resp_message.cluster_state() != current_cluster_state_) {
|
||||||
|
MS_LOG(INFO) << "cluster change state from:" << CommUtil::ClusterStateToString(current_cluster_state_) << " to "
|
||||||
|
<< CommUtil::ClusterStateToString(heartbeat_resp_message.cluster_state());
|
||||||
|
}
|
||||||
|
|
||||||
current_cluster_state_ = heartbeat_resp_message.cluster_state();
|
current_cluster_state_ = heartbeat_resp_message.cluster_state();
|
||||||
MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
|
MS_LOG(DEBUG) << "The current cluster state from heartbeat:"
|
||||||
<< CommUtil::ClusterStateToString(current_cluster_state_);
|
<< CommUtil::ClusterStateToString(current_cluster_state_);
|
||||||
|
|
||||||
|
std::string timeoutNodeId;
|
||||||
|
|
||||||
all_nodes_info_.clear();
|
all_nodes_info_.clear();
|
||||||
for (const auto &it : heartbeat_resp_message.servers_meta()) {
|
for (const auto &it : heartbeat_resp_message.servers_meta()) {
|
||||||
NodeInfo info;
|
NodeInfo info;
|
||||||
info.ip_ = it.ip();
|
info.ip_ = it.ip();
|
||||||
info.node_id_ = it.node_id();
|
info.node_id_ = it.node_id();
|
||||||
info.port_ = static_cast<uint16_t>(it.port());
|
info.port_ = it.port();
|
||||||
info.node_role_ = it.role();
|
info.node_role_ = it.role();
|
||||||
info.rank_id_ = it.rank_id();
|
info.rank_id_ = it.rank_id();
|
||||||
info.is_alive = it.is_alive();
|
info.is_alive = it.is_alive();
|
||||||
|
|
||||||
|
if (!info.is_alive) {
|
||||||
|
timeoutNodeId += (info.node_id_ + " ");
|
||||||
|
}
|
||||||
|
|
||||||
all_nodes_info_[info.node_id_] = info;
|
all_nodes_info_[info.node_id_] = info;
|
||||||
MS_LOG(DEBUG) << "The node id:" << info.node_id_ << ", the rank id:" << info.rank_id_
|
MS_LOG(DEBUG) << "The node id:" << info.node_id_ << ", the rank id:" << info.rank_id_
|
||||||
<< ", the node role:" << CommUtil::NodeRoleToString(info.node_role_) << " is alive:" << info.is_alive;
|
<< ", the node role:" << CommUtil::NodeRoleToString(info.node_role_) << " is alive:" << info.is_alive;
|
||||||
|
@ -608,7 +607,8 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr<MessageMeta> &meta
|
||||||
wait_start_cond_.notify_all();
|
wait_start_cond_.notify_all();
|
||||||
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
|
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(INFO) << "The node is support recovery, users can pull up this node to restore the cluster.";
|
MS_LOG(INFO) << "The nodes:" << timeoutNodeId
|
||||||
|
<< "is support recovery, users can pull up this node to restore the cluster.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -653,14 +653,14 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SendMetadataMessage send_meta_message;
|
SendMetadataMessage send_meta_message;
|
||||||
CHECK_RETURN_TYPE(send_meta_message.ParseFromArray(data, SizeToInt(size)));
|
send_meta_message.ParseFromArray(data, SizeToInt(size));
|
||||||
worker_num_ = send_meta_message.worker_num();
|
worker_num_ = send_meta_message.worker_num();
|
||||||
server_num_ = send_meta_message.server_num();
|
server_num_ = send_meta_message.server_num();
|
||||||
if (send_meta_message.rank_id() < 0) {
|
if (send_meta_message.rank_id() < 0) {
|
||||||
MS_LOG(EXCEPTION) << "The rank id is wrong.";
|
MS_LOG(EXCEPTION) << "The rank id is wrong.";
|
||||||
}
|
}
|
||||||
node_info_.rank_id_ = send_meta_message.rank_id();
|
node_info_.rank_id_ = send_meta_message.rank_id();
|
||||||
current_cluster_state_ = send_meta_message.cluster_state();
|
UpdateClusterState(send_meta_message.cluster_state());
|
||||||
MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_
|
MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_
|
||||||
<< ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_)
|
<< ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_)
|
||||||
<< ", the rank id:" << node_info_.rank_id_;
|
<< ", the rank id:" << node_info_.rank_id_;
|
||||||
|
@ -669,7 +669,8 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
|
||||||
nodes_address_.clear();
|
nodes_address_.clear();
|
||||||
for (const auto &it : send_meta_message.servers_meta()) {
|
for (const auto &it : send_meta_message.servers_meta()) {
|
||||||
nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port());
|
nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port());
|
||||||
MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port() << ", the rank id:" << it.rank_id();
|
MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(it.role()) << ", node id:" << it.node_id()
|
||||||
|
<< ", rank id:" << it.rank_id() << ", ip:" << it.ip() << ", port:" << it.port();
|
||||||
}
|
}
|
||||||
client_mutex_.unlock();
|
client_mutex_.unlock();
|
||||||
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
|
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
|
||||||
|
@ -690,6 +691,7 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr<TcpConnection> &con
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(client_mutex_);
|
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||||
connected_nodes_.clear();
|
connected_nodes_.clear();
|
||||||
|
PersistMetaData();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
void AbstractNode::ProcessFinish(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||||
|
@ -710,11 +712,13 @@ void AbstractNode::ProcessScaleOutDone(const std::shared_ptr<TcpConnection> &con
|
||||||
MS_EXCEPTION_IF_NULL(conn);
|
MS_EXCEPTION_IF_NULL(conn);
|
||||||
MS_EXCEPTION_IF_NULL(meta);
|
MS_EXCEPTION_IF_NULL(meta);
|
||||||
MS_EXCEPTION_IF_NULL(data);
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
|
MS_LOG(INFO) << "This node receive a scale out done from scheduler.";
|
||||||
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
|
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
|
||||||
MS_LOG(WARNING) << "Server response message failed.";
|
MS_LOG(WARNING) << "Server response message failed.";
|
||||||
}
|
}
|
||||||
is_ready_ = true;
|
is_ready_ = true;
|
||||||
current_cluster_state_ = ClusterState::CLUSTER_READY;
|
UpdateClusterState(ClusterState::CLUSTER_READY);
|
||||||
|
PersistMetaData();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
|
void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn,
|
||||||
|
@ -727,7 +731,8 @@ void AbstractNode::ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn
|
||||||
MS_LOG(WARNING) << "Server response message failed.";
|
MS_LOG(WARNING) << "Server response message failed.";
|
||||||
}
|
}
|
||||||
is_ready_ = true;
|
is_ready_ = true;
|
||||||
current_cluster_state_ = ClusterState::CLUSTER_READY;
|
UpdateClusterState(ClusterState::CLUSTER_READY);
|
||||||
|
PersistMetaData();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||||
|
@ -736,12 +741,17 @@ void AbstractNode::ProcessEvent(const std::shared_ptr<TcpConnection> &conn, cons
|
||||||
MS_EXCEPTION_IF_NULL(meta);
|
MS_EXCEPTION_IF_NULL(meta);
|
||||||
MS_EXCEPTION_IF_NULL(data);
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
EventRespMessage event_resp_message;
|
EventRespMessage event_resp_message;
|
||||||
CHECK_RETURN_TYPE(event_resp_message.ParseFromArray(data, SizeToInt(size)));
|
event_resp_message.ParseFromArray(data, SizeToInt(size));
|
||||||
uint32_t event = event_resp_message.event();
|
uint32_t event = event_resp_message.event();
|
||||||
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
|
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
|
||||||
MS_LOG(WARNING) << "Server response message failed.";
|
MS_LOG(WARNING) << "Server response message failed.";
|
||||||
}
|
}
|
||||||
OnCustomEventCallback(event);
|
MS_LOG(INFO) << "This node receive a event:" << event;
|
||||||
|
if (event == static_cast<uint32_t>(ps::UserDefineEvent::kNodeTimeout)) {
|
||||||
|
OnEventCallback(ClusterEvent::NODE_TIMEOUT);
|
||||||
|
} else {
|
||||||
|
OnCustomEventCallback(event);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||||
|
@ -751,7 +761,7 @@ void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, c
|
||||||
MS_EXCEPTION_IF_NULL(data);
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
|
|
||||||
ScaleOutMessage scale_out_message;
|
ScaleOutMessage scale_out_message;
|
||||||
CHECK_RETURN_TYPE(scale_out_message.ParseFromArray(data, SizeToInt(size)));
|
scale_out_message.ParseFromArray(data, SizeToInt(size));
|
||||||
int32_t worker_num = scale_out_message.worker_num();
|
int32_t worker_num = scale_out_message.worker_num();
|
||||||
int32_t server_num = scale_out_message.server_num();
|
int32_t server_num = scale_out_message.server_num();
|
||||||
MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num;
|
MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num;
|
||||||
|
@ -760,7 +770,7 @@ void AbstractNode::ProcessScaleOut(const std::shared_ptr<TcpConnection> &conn, c
|
||||||
MS_LOG(WARNING) << "Server response message failed.";
|
MS_LOG(WARNING) << "Server response message failed.";
|
||||||
}
|
}
|
||||||
OnEventCallback(ClusterEvent::READY_FOR_SCALE_OUT);
|
OnEventCallback(ClusterEvent::READY_FOR_SCALE_OUT);
|
||||||
current_cluster_state_ = ClusterState::CLUSTER_SCALE_OUT;
|
UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT);
|
||||||
is_ready_ = false;
|
is_ready_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -771,7 +781,7 @@ void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, co
|
||||||
MS_EXCEPTION_IF_NULL(data);
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
|
|
||||||
ScaleInMessage scale_in_message;
|
ScaleInMessage scale_in_message;
|
||||||
CHECK_RETURN_TYPE(scale_in_message.ParseFromArray(data, SizeToInt(size)));
|
scale_in_message.ParseFromArray(data, SizeToInt(size));
|
||||||
int32_t worker_num = scale_in_message.worker_num();
|
int32_t worker_num = scale_in_message.worker_num();
|
||||||
int32_t server_num = scale_in_message.server_num();
|
int32_t server_num = scale_in_message.server_num();
|
||||||
MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num;
|
MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num;
|
||||||
|
@ -789,7 +799,44 @@ void AbstractNode::ProcessScaleIn(const std::shared_ptr<TcpConnection> &conn, co
|
||||||
MS_LOG(WARNING) << "Server response message failed.";
|
MS_LOG(WARNING) << "Server response message failed.";
|
||||||
}
|
}
|
||||||
OnEventCallback(ClusterEvent::READY_FOR_SCALE_IN);
|
OnEventCallback(ClusterEvent::READY_FOR_SCALE_IN);
|
||||||
current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN;
|
UpdateClusterState(ClusterState::CLUSTER_SCALE_IN);
|
||||||
|
is_ready_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AbstractNode::ProcessSchedulerRecovery(const std::shared_ptr<TcpConnection> &conn,
|
||||||
|
const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
|
||||||
|
size_t size) {
|
||||||
|
MS_EXCEPTION_IF_NULL(conn);
|
||||||
|
MS_EXCEPTION_IF_NULL(meta);
|
||||||
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
|
if (is_connected_to_scheduler_.load()) {
|
||||||
|
MS_LOG(WARNING) << "This node has been connected to scheduler.";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
SendMetadataMessage scheduler_recovery_message;
|
||||||
|
(void)scheduler_recovery_message.ParseFromArray(data, SizeToInt(size));
|
||||||
|
worker_num_ = scheduler_recovery_message.worker_num();
|
||||||
|
server_num_ = scheduler_recovery_message.server_num();
|
||||||
|
uint32_t rank_id = scheduler_recovery_message.rank_id();
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "[Scheduler Recovery]: The scheduler recovery worker num:" << worker_num_
|
||||||
|
<< ", the server num:" << server_num_ << ", the rank id: " << rank_id;
|
||||||
|
|
||||||
|
if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) {
|
||||||
|
MS_LOG(WARNING) << "[Scheduler Recovery]: Server response message failed.";
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "[Scheduler Recovery]: Server response message success!.";
|
||||||
|
|
||||||
|
if (!InitClientToScheduler()) {
|
||||||
|
MS_LOG(WARNING) << "[Scheduler Recovery]: Server node connect to scheduler timedout!";
|
||||||
|
}
|
||||||
|
|
||||||
|
Register(client_to_scheduler_);
|
||||||
|
std::lock_guard<std::mutex> lock(client_mutex_);
|
||||||
|
connected_nodes_.clear();
|
||||||
|
MS_LOG(INFO) << "[Scheduler Recovery]: This node connect to scheduler successful!";
|
||||||
|
|
||||||
|
UpdateClusterState(ClusterState::CLUSTER_SCHEDULER_RECOVERY);
|
||||||
is_ready_ = false;
|
is_ready_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -819,6 +866,14 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AbstractNode::InitClientToServer() {
|
||||||
|
// create tcp client to myself in case of event dispatch failed when Send msg to server 0 failed
|
||||||
|
client_to_server_ = std::make_shared<TcpClient>(node_info_.ip_, node_info_.port_, config_.get());
|
||||||
|
MS_EXCEPTION_IF_NULL(client_to_server_);
|
||||||
|
client_to_server_->Init();
|
||||||
|
MS_LOG(INFO) << "The node start a tcp client to this node!";
|
||||||
|
}
|
||||||
|
|
||||||
bool AbstractNode::InitClientToScheduler() {
|
bool AbstractNode::InitClientToScheduler() {
|
||||||
if (config_ == nullptr) {
|
if (config_ == nullptr) {
|
||||||
MS_LOG(WARNING) << "The config is empty.";
|
MS_LOG(WARNING) << "The config is empty.";
|
||||||
|
@ -843,7 +898,6 @@ bool AbstractNode::InitClientToScheduler() {
|
||||||
MsException::Instance().SetException();
|
MsException::Instance().SetException();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
client_to_scheduler_->Init();
|
client_to_scheduler_->Init();
|
||||||
client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
|
client_to_scheduler_thread_ = std::make_unique<std::thread>([&]() {
|
||||||
MS_LOG(INFO) << "The node start a tcp client!";
|
MS_LOG(INFO) << "The node start a tcp client!";
|
||||||
|
@ -851,11 +905,14 @@ bool AbstractNode::InitClientToScheduler() {
|
||||||
});
|
});
|
||||||
client_to_scheduler_thread_->detach();
|
client_to_scheduler_thread_->detach();
|
||||||
|
|
||||||
|
client_to_scheduler_->set_connected_callback([&]() { is_connected_to_scheduler_ = true; });
|
||||||
|
|
||||||
client_to_scheduler_->set_disconnected_callback([&]() {
|
client_to_scheduler_->set_disconnected_callback([&]() {
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval));
|
std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval));
|
||||||
if (is_ready_.load() == false) {
|
if (is_ready_.load() == false) {
|
||||||
client_to_scheduler_->Init();
|
client_to_scheduler_->Init();
|
||||||
}
|
}
|
||||||
|
is_connected_to_scheduler_ = false;
|
||||||
});
|
});
|
||||||
bool wait_res = client_to_scheduler_->WaitConnected();
|
bool wait_res = client_to_scheduler_->WaitConnected();
|
||||||
if (!wait_res) {
|
if (!wait_res) {
|
||||||
|
@ -892,6 +949,9 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
|
||||||
case NodeCommand::COLLECTIVE_SEND_DATA:
|
case NodeCommand::COLLECTIVE_SEND_DATA:
|
||||||
MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
|
MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!";
|
||||||
break;
|
break;
|
||||||
|
case NodeCommand::SEND_EVENT:
|
||||||
|
MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a send_event command message response!";
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!";
|
||||||
}
|
}
|
||||||
|
@ -964,8 +1024,9 @@ void AbstractNode::ProcessSendData(const std::shared_ptr<TcpConnection> &conn, c
|
||||||
if (size > 0) {
|
if (size > 0) {
|
||||||
size_t dest_size = size;
|
size_t dest_size = size;
|
||||||
size_t src_size = size;
|
size_t src_size = size;
|
||||||
if (memcpy_s(res.get(), dest_size, data, src_size) != EOK) {
|
auto ret = memcpy_s(res.get(), dest_size, data, src_size);
|
||||||
MS_LOG(EXCEPTION) << "The memcpy_s error";
|
if (ret != EOK) {
|
||||||
|
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
|
||||||
|
@ -1066,6 +1127,7 @@ void AbstractNode::InitServerHandler() {
|
||||||
server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
|
server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone;
|
||||||
server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
|
server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone;
|
||||||
server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent;
|
server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent;
|
||||||
|
server_handler_[NodeCommand::SCHEDULER_RECOVERY] = &AbstractNode::ProcessSchedulerRecovery;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AbstractNode::InitNodeInfo(const NodeRole &role) {
|
void AbstractNode::InitNodeInfo(const NodeRole &role) {
|
||||||
|
@ -1090,8 +1152,8 @@ void AbstractNode::InitNodeInfo(const NodeRole &role) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void AbstractNode::InitNodeNum() {
|
void AbstractNode::InitNodeNum() {
|
||||||
worker_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_worker_num);
|
worker_num_ = UintToInt(PSContext::instance()->cluster_config().initial_worker_num);
|
||||||
server_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_server_num);
|
server_num_ = UintToInt(PSContext::instance()->cluster_config().initial_server_num);
|
||||||
scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host;
|
scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host;
|
||||||
scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port;
|
scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port;
|
||||||
MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_
|
MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_
|
||||||
|
@ -1104,7 +1166,10 @@ bool AbstractNode::Recover() {
|
||||||
MS_LOG(INFO) << "The node is support recovery.";
|
MS_LOG(INFO) << "The node is support recovery.";
|
||||||
node_recovery_ = std::make_unique<NodeRecovery>(this);
|
node_recovery_ = std::make_unique<NodeRecovery>(this);
|
||||||
MS_EXCEPTION_IF_NULL(node_recovery_);
|
MS_EXCEPTION_IF_NULL(node_recovery_);
|
||||||
node_recovery_->Initialize(config_->Get(kKeyRecovery, ""));
|
if (!node_recovery_->Initialize(config_->Get(kKeyRecovery, ""))) {
|
||||||
|
MS_LOG(ERROR) << "Initializing node recovery failed.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return node_recovery_->Recover();
|
return node_recovery_->Recover();
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -1132,7 +1197,7 @@ void AbstractNode::OnCustomEventCallback(const uint32_t &event) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AbstractNode::IsWorkerOrServer0(const mindspore::HashMap<std::string, NodeInfo> &info) {
|
bool AbstractNode::IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info) {
|
||||||
for (const auto &it : info) {
|
for (const auto &it : info) {
|
||||||
if (it.second.is_alive == true && it.second.node_role_ == NodeRole::WORKER) {
|
if (it.second.is_alive == true && it.second.node_role_ == NodeRole::WORKER) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -1149,7 +1214,12 @@ void AbstractNode::CreateTcpServer() {
|
||||||
MS_EXCEPTION_IF_NULL(config_);
|
MS_EXCEPTION_IF_NULL(config_);
|
||||||
std::string interface;
|
std::string interface;
|
||||||
std::string server_ip;
|
std::string server_ip;
|
||||||
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
|
if (ps::PSContext::instance()->server_mode().empty()) {
|
||||||
|
// If the server mode is not set, use 127.0.0.1 as server ip address for distributed learning.
|
||||||
|
server_ip = "127.0.0.1";
|
||||||
|
} else {
|
||||||
|
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
|
||||||
|
}
|
||||||
server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get());
|
server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get());
|
||||||
MS_EXCEPTION_IF_NULL(server_);
|
MS_EXCEPTION_IF_NULL(server_);
|
||||||
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||||
|
@ -1179,6 +1249,29 @@ void AbstractNode::CreateTcpServer() {
|
||||||
MS_EXCEPTION_IF_NULL(server_thread_);
|
MS_EXCEPTION_IF_NULL(server_thread_);
|
||||||
server_thread_->detach();
|
server_thread_->detach();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AbstractNode::UpdateClusterState(const ClusterState &state) {
|
||||||
|
std::lock_guard<std::mutex> lk(cluster_state_mutex_);
|
||||||
|
MS_LOG(INFO) << "[state]: Cluster state change from:" << CommUtil::ClusterStateToString(current_cluster_state_)
|
||||||
|
<< " to " << CommUtil::ClusterStateToString(state);
|
||||||
|
current_cluster_state_ = state;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AbstractNode::PersistMetaData() {
|
||||||
|
if (node_recovery_ == nullptr) {
|
||||||
|
MS_LOG(WARNING) << "node recovery is null, so don't persist meta data";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (config_->Exists(kKeyRecovery)) {
|
||||||
|
ClusterConfig &clusterConfig = PSContext::instance()->cluster_config();
|
||||||
|
clusterConfig.scheduler_host = this->scheduler_ip();
|
||||||
|
clusterConfig.scheduler_port = this->scheduler_port();
|
||||||
|
clusterConfig.initial_worker_num = worker_num_;
|
||||||
|
clusterConfig.initial_server_num = server_num_;
|
||||||
|
|
||||||
|
node_recovery_->Persist(clusterConfig);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,8 +22,8 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "utils/hash_map.h"
|
|
||||||
#include "ps/core/node.h"
|
#include "ps/core/node.h"
|
||||||
#include "ps/core/communicator/message.h"
|
#include "ps/core/communicator/message.h"
|
||||||
#include "ps/core/follower_scaler.h"
|
#include "ps/core/follower_scaler.h"
|
||||||
|
@ -44,10 +44,12 @@ class AbstractNode : public Node {
|
||||||
: heart_beat_thread_(nullptr),
|
: heart_beat_thread_(nullptr),
|
||||||
client_to_scheduler_thread_(nullptr),
|
client_to_scheduler_thread_(nullptr),
|
||||||
client_to_scheduler_(nullptr),
|
client_to_scheduler_(nullptr),
|
||||||
|
client_to_server_(nullptr),
|
||||||
server_(nullptr),
|
server_(nullptr),
|
||||||
server_thread_(nullptr),
|
server_thread_(nullptr),
|
||||||
worker_num_(-1),
|
worker_num_(-1),
|
||||||
server_num_(-1),
|
server_num_(-1),
|
||||||
|
is_connected_to_scheduler_(false),
|
||||||
is_current_node_scale_in_(false),
|
is_current_node_scale_in_(false),
|
||||||
follower_scaler_(nullptr),
|
follower_scaler_(nullptr),
|
||||||
node_recovery_(nullptr),
|
node_recovery_(nullptr),
|
||||||
|
@ -94,14 +96,14 @@ class AbstractNode : public Node {
|
||||||
void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb);
|
void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb);
|
||||||
|
|
||||||
bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
|
bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command,
|
||||||
const uint32_t &timeout = kTimeoutInSeconds);
|
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||||
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
|
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
|
||||||
const std::vector<size_t> &lens, int command, const uint32_t &timeout = kTimeoutInSeconds);
|
const std::vector<size_t> &lens, int command, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||||
bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command,
|
bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command,
|
||||||
VectorPtr *output, const uint32_t &timeout = kTimeoutInSeconds);
|
VectorPtr *output, const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||||
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
|
bool Send(const NodeRole &node_role, const std::vector<uint32_t> &rank_ids, const std::vector<DataPtr> &data,
|
||||||
const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
|
const std::vector<size_t> &data_lens, int command, std::vector<VectorPtr> *output,
|
||||||
const uint32_t &timeout = kTimeoutInSeconds);
|
const uint32_t &timeout = kCommTimeoutInSeconds);
|
||||||
|
|
||||||
uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
|
uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size);
|
||||||
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
|
std::pair<uint32_t, uint64_t> CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id,
|
||||||
|
@ -170,6 +172,10 @@ class AbstractNode : public Node {
|
||||||
void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
void ProcessScaleInDone(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||||
const Protos &protos, const void *data, size_t size);
|
const Protos &protos, const void *data, size_t size);
|
||||||
|
|
||||||
|
// The worker/server processes the scheduler recovery message from the scheduler
|
||||||
|
void ProcessSchedulerRecovery(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||||
|
const Protos &, const void *data, size_t size);
|
||||||
|
|
||||||
// The worker/server processes the SEND_EVENT message from scheduelr
|
// The worker/server processes the SEND_EVENT message from scheduelr
|
||||||
void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
void ProcessEvent(const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
|
||||||
const Protos &protos, const void *data, size_t size);
|
const Protos &protos, const void *data, size_t size);
|
||||||
|
@ -180,6 +186,7 @@ class AbstractNode : public Node {
|
||||||
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
|
bool Disconnect(const std::shared_ptr<TcpClient> &client, const uint32_t &timeout);
|
||||||
bool WaitForDisconnect(const uint32_t &timeout);
|
bool WaitForDisconnect(const uint32_t &timeout);
|
||||||
bool InitClientToScheduler();
|
bool InitClientToScheduler();
|
||||||
|
void InitClientToServer();
|
||||||
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id,
|
const std::shared_ptr<TcpClient> &GetOrCreateTcpClient(const uint32_t &rank_id,
|
||||||
const NodeRole &role = NodeRole::SERVER);
|
const NodeRole &role = NodeRole::SERVER);
|
||||||
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
bool SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
|
||||||
|
@ -212,14 +219,18 @@ class AbstractNode : public Node {
|
||||||
// Trigger the callback corresponding to the custom event.
|
// Trigger the callback corresponding to the custom event.
|
||||||
void OnCustomEventCallback(const uint32_t &event);
|
void OnCustomEventCallback(const uint32_t &event);
|
||||||
|
|
||||||
bool IsWorkerOrServer0(const mindspore::HashMap<std::string, NodeInfo> &info);
|
bool IsWorkerOrServer0(const std::unordered_map<std::string, NodeInfo> &info);
|
||||||
|
|
||||||
void CreateTcpServer();
|
void CreateTcpServer();
|
||||||
|
|
||||||
|
void UpdateClusterState(const ClusterState &state);
|
||||||
|
|
||||||
|
void PersistMetaData();
|
||||||
|
|
||||||
std::unique_ptr<std::thread> heart_beat_thread_;
|
std::unique_ptr<std::thread> heart_beat_thread_;
|
||||||
std::unique_ptr<std::thread> client_to_scheduler_thread_;
|
std::unique_ptr<std::thread> client_to_scheduler_thread_;
|
||||||
std::shared_ptr<TcpClient> client_to_scheduler_;
|
std::shared_ptr<TcpClient> client_to_scheduler_;
|
||||||
|
std::shared_ptr<TcpClient> client_to_server_;
|
||||||
// the key is: <node_role,rank_id>, the value is: <ip, port>
|
// the key is: <node_role,rank_id>, the value is: <ip, port>
|
||||||
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
|
std::map<std::pair<NodeRole, uint32_t>, std::pair<std::string, uint16_t>> nodes_address_;
|
||||||
// the map's key is: rank_id
|
// the map's key is: rank_id
|
||||||
|
@ -233,13 +244,13 @@ class AbstractNode : public Node {
|
||||||
std::condition_variable receive_cond_;
|
std::condition_variable receive_cond_;
|
||||||
|
|
||||||
// the key is rank_id, the value is rank_id's expected request_id
|
// the key is rank_id, the value is rank_id's expected request_id
|
||||||
mindspore::HashMap<uint32_t, uint64_t> expected_rank_request_ids_;
|
std::unordered_map<uint32_t, uint64_t> expected_rank_request_ids_;
|
||||||
// the key is rank_id, the value is rank_id's actual request_id
|
// the key is rank_id, the value is rank_id's actual request_id
|
||||||
mindspore::HashMap<uint32_t, uint64_t> actual_rank_request_ids_;
|
std::unordered_map<uint32_t, uint64_t> actual_rank_request_ids_;
|
||||||
std::mutex rank_request_ids_mutex;
|
std::mutex rank_request_ids_mutex;
|
||||||
timeval scheduler_time_{0, 0};
|
timeval scheduler_time_{0, 0};
|
||||||
mindspore::HashMap<NodeCommand, ResponseHandler> handlers_;
|
std::unordered_map<NodeCommand, ResponseHandler> handlers_;
|
||||||
mindspore::HashMap<NodeCommand, ServerHandler> server_handler_;
|
std::unordered_map<NodeCommand, ServerHandler> server_handler_;
|
||||||
|
|
||||||
// Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA
|
// Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA
|
||||||
std::shared_ptr<TcpServer> server_;
|
std::shared_ptr<TcpServer> server_;
|
||||||
|
@ -247,7 +258,7 @@ class AbstractNode : public Node {
|
||||||
|
|
||||||
int32_t worker_num_;
|
int32_t worker_num_;
|
||||||
int32_t server_num_;
|
int32_t server_num_;
|
||||||
|
std::atomic<bool> is_connected_to_scheduler_;
|
||||||
// Identify whether the current node is a scale in node.
|
// Identify whether the current node is a scale in node.
|
||||||
std::atomic<bool> is_current_node_scale_in_;
|
std::atomic<bool> is_current_node_scale_in_;
|
||||||
|
|
||||||
|
@ -273,11 +284,12 @@ class AbstractNode : public Node {
|
||||||
uint16_t scheduler_port_;
|
uint16_t scheduler_port_;
|
||||||
|
|
||||||
// Synchronize all node metadata from the scheduler.
|
// Synchronize all node metadata from the scheduler.
|
||||||
mindspore::HashMap<std::string, NodeInfo> all_nodes_info_;
|
std::unordered_map<std::string, NodeInfo> all_nodes_info_;
|
||||||
RequestHandler request_handler_;
|
RequestHandler request_handler_;
|
||||||
|
|
||||||
mindspore::HashMap<std::string, std::shared_ptr<CommunicatorBase>> communicators_;
|
std::unordered_map<std::string, std::shared_ptr<CommunicatorBase>> communicators_;
|
||||||
std::mutex communicator_mutex_;
|
std::mutex communicator_mutex_;
|
||||||
|
std::mutex cluster_state_mutex_;
|
||||||
};
|
};
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
|
|
|
@ -21,8 +21,10 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#include "ps/core/node_info.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
|
@ -38,10 +40,12 @@ struct ClusterConfig {
|
||||||
scheduler_host(host),
|
scheduler_host(host),
|
||||||
scheduler_port(port),
|
scheduler_port(port),
|
||||||
heartbeat_timeout(30),
|
heartbeat_timeout(30),
|
||||||
cluster_available_timeout(300),
|
cluster_available_timeout(900),
|
||||||
connect_interval(3000),
|
connect_interval(3000),
|
||||||
scheduler_timeout(30) {}
|
scheduler_timeout(30),
|
||||||
|
initial_total_node_num(0),
|
||||||
|
initial_next_worker_rank_id(0),
|
||||||
|
initial_next_server_rank_id(0) {}
|
||||||
// Configure through environment variables:MS_WORKER_NUM
|
// Configure through environment variables:MS_WORKER_NUM
|
||||||
uint32_t initial_worker_num;
|
uint32_t initial_worker_num;
|
||||||
// Configure through environment variables:MS_SERVER_NUM
|
// Configure through environment variables:MS_SERVER_NUM
|
||||||
|
@ -59,6 +63,11 @@ struct ClusterConfig {
|
||||||
uint32_t connect_interval;
|
uint32_t connect_interval;
|
||||||
// When the scheduler exits, the worker and server can continue to work for 5 hours
|
// When the scheduler exits, the worker and server can continue to work for 5 hours
|
||||||
int64_t scheduler_timeout;
|
int64_t scheduler_timeout;
|
||||||
|
// the nodes that have been registered to the scheduler
|
||||||
|
std::unordered_map<std::string, NodeInfo> initial_registered_nodes_infos;
|
||||||
|
uint32_t initial_total_node_num;
|
||||||
|
uint32_t initial_next_worker_rank_id;
|
||||||
|
uint32_t initial_next_server_rank_id;
|
||||||
};
|
};
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
|
|
|
@ -32,7 +32,6 @@ namespace core {
|
||||||
*/
|
*/
|
||||||
struct ClusterMetadata {
|
struct ClusterMetadata {
|
||||||
ClusterMetadata(const uint32_t &worker, const uint32_t &server) : worker_num(worker), server_num(server) {}
|
ClusterMetadata(const uint32_t &worker, const uint32_t &server) : worker_num(worker), server_num(server) {}
|
||||||
|
|
||||||
uint32_t worker_num;
|
uint32_t worker_num;
|
||||||
uint32_t server_num;
|
uint32_t server_num;
|
||||||
};
|
};
|
||||||
|
|
|
@ -96,6 +96,28 @@ void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *i
|
||||||
freeifaddrs(if_address);
|
freeifaddrs(if_address);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string CommUtil::GetLoopBackInterfaceName() {
|
||||||
|
struct ifaddrs *if_address = nullptr;
|
||||||
|
struct ifaddrs *ifa = nullptr;
|
||||||
|
|
||||||
|
if (getifaddrs(&if_address) == -1) {
|
||||||
|
MS_LOG(WARNING) << "Get ifaddrs failed.";
|
||||||
|
}
|
||||||
|
for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) {
|
||||||
|
if (ifa->ifa_addr == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ifa->ifa_flags & IFF_LOOPBACK) {
|
||||||
|
MS_LOG(INFO) << "Loop back interface name is " << ifa->ifa_name;
|
||||||
|
return ifa->ifa_name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MS_EXCEPTION_IF_NULL(if_address);
|
||||||
|
freeifaddrs(if_address);
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
std::string CommUtil::GenerateUUID() {
|
std::string CommUtil::GenerateUUID() {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
int i;
|
int i;
|
||||||
|
@ -135,6 +157,18 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) {
|
||||||
MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
|
MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Maps the textual role name ("SCHEDULER", "SERVER" or "WORKER") to the
// corresponding NodeRole value. Any other string is reported through
// MS_LOG(EXCEPTION).
NodeRole CommUtil::StringToNodeRole(const std::string &roleStr) {
  if (roleStr == "SCHEDULER") {
    return NodeRole::SCHEDULER;
  }
  if (roleStr == "SERVER") {
    return NodeRole::SERVER;
  }
  if (roleStr == "WORKER") {
    return NodeRole::WORKER;
  }
  MS_LOG(EXCEPTION) << "The node role string:" << roleStr << " is illegal!";
}
|
||||||
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
|
bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
|
||||||
const int32_t &total_server_num) {
|
const int32_t &total_server_num) {
|
||||||
if (node_role == NodeRole::SERVER && (rank_id > IntToUint(total_server_num) - 1)) {
|
if (node_role == NodeRole::SERVER && (rank_id > IntToUint(total_server_num) - 1)) {
|
||||||
|
@ -183,7 +217,7 @@ bool CommUtil::IsFileExists(const std::string &file) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string CommUtil::ClusterStateToString(const ClusterState &state) {
|
std::string CommUtil::ClusterStateToString(const ClusterState &state) {
|
||||||
MS_LOG(INFO) << "The cluster state:" << state;
|
MS_LOG(DEBUG) << "The cluster state:" << state;
|
||||||
if (state < SizeToInt(kClusterState.size())) {
|
if (state < SizeToInt(kClusterState.size())) {
|
||||||
return kClusterState.at(state);
|
return kClusterState.at(state);
|
||||||
} else {
|
} else {
|
||||||
|
@ -191,21 +225,50 @@ std::string CommUtil::ClusterStateToString(const ClusterState &state) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string CommUtil::ParseConfig(const Configuration &config, const std::string &data) {
|
std::string CommUtil::ParseConfig(const Configuration &config, const std::string &key) {
|
||||||
if (!config.IsInitialized()) {
|
if (!config.IsInitialized()) {
|
||||||
MS_LOG(INFO) << "The config is not initialized.";
|
MS_LOG(INFO) << "The config is not initialized.";
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!const_cast<Configuration &>(config).Exists(data)) {
|
if (!const_cast<Configuration &>(config).Exists(key)) {
|
||||||
MS_LOG(INFO) << "The data:" << data << " is not exist.";
|
MS_LOG(INFO) << "The key:" << key << " is not exist.";
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string path = config.GetString(data, "");
|
std::string path = config.GetString(key, "");
|
||||||
return path;
|
return path;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool CommUtil::verifyCertTimeStamp(const X509 *cert) {
|
||||||
|
ASN1_TIME *start = X509_getm_notBefore(cert);
|
||||||
|
ASN1_TIME *end = X509_getm_notAfter(cert);
|
||||||
|
|
||||||
|
int day = 0;
|
||||||
|
int sec = 0;
|
||||||
|
int ret = ASN1_TIME_diff(&day, &sec, start, NULL);
|
||||||
|
if (ret != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (day < 0 || sec < 0) {
|
||||||
|
MS_LOG(ERROR) << "cert start time is later than now time.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
day = 0;
|
||||||
|
sec = 0;
|
||||||
|
ret = ASN1_TIME_diff(&day, &sec, NULL, end);
|
||||||
|
if (ret != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (day < 0 || sec < 0) {
|
||||||
|
MS_LOG(ERROR) << "cert end time is sooner than now time.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool CommUtil::VerifyCertTime(const X509 *cert, int64_t time) {
|
bool CommUtil::VerifyCertTime(const X509 *cert, int64_t time) {
|
||||||
MS_EXCEPTION_IF_NULL(cert);
|
MS_EXCEPTION_IF_NULL(cert);
|
||||||
ASN1_TIME *start = X509_getm_notBefore(cert);
|
ASN1_TIME *start = X509_getm_notBefore(cert);
|
||||||
|
@ -249,61 +312,55 @@ bool CommUtil::VerifyCRL(const X509 *cert, const std::string &crl_path) {
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
|
||||||
BIO *bio = BIO_new_file(crl_path.c_str(), "r");
|
BIO *bio = BIO_new_file(crl_path.c_str(), "r");
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(bio, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(bio, false);
|
||||||
X509_CRL *root_crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(root_crl, false);
|
|
||||||
EVP_PKEY *evp_pkey = X509_get_pubkey(const_cast<X509 *>(cert));
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(evp_pkey, false);
|
|
||||||
|
|
||||||
int ret = X509_CRL_verify(root_crl, evp_pkey);
|
X509_CRL *root_crl = nullptr;
|
||||||
|
EVP_PKEY *evp_pkey = nullptr;
|
||||||
|
bool result = true;
|
||||||
|
do {
|
||||||
|
root_crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
|
||||||
|
|
||||||
|
MS_ERROR_IF_NULL_W_RET_VAL(root_crl, false);
|
||||||
|
evp_pkey = X509_get_pubkey(const_cast<X509 *>(cert));
|
||||||
|
MS_ERROR_IF_NULL_W_RET_VAL(evp_pkey, false);
|
||||||
|
|
||||||
|
int ret = X509_CRL_verify(root_crl, evp_pkey);
|
||||||
|
if (ret == 1) {
|
||||||
|
MS_LOG(WARNING) << "Equip cert in root crl, verify failed";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} while (0);
|
||||||
|
|
||||||
BIO_free_all(bio);
|
BIO_free_all(bio);
|
||||||
if (ret == 1) {
|
EVP_PKEY_free(evp_pkey);
|
||||||
MS_LOG(WARNING) << "Equip cert in root crl, verify failed";
|
X509_CRL_free(root_crl);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "VerifyCRL success.";
|
MS_LOG(INFO) << "VerifyCRL success.";
|
||||||
return true;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CommUtil::VerifyCommonName(const X509 *cert, const std::string &ca_path) {
|
bool CommUtil::VerifyCommonName(const X509 *caCert, const X509 *subCert) {
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
|
MS_EXCEPTION_IF_NULL(caCert);
|
||||||
X509 *cert_temp = const_cast<X509 *>(cert);
|
MS_EXCEPTION_IF_NULL(subCert);
|
||||||
char subject_cn[256] = "";
|
char caSubjectCN[256] = "";
|
||||||
char issuer_cn[256] = "";
|
char subIssuerCN[256] = "";
|
||||||
X509_NAME *subject_name = X509_get_subject_name(cert_temp);
|
|
||||||
X509_NAME *issuer_name = X509_get_issuer_name(cert_temp);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(subject_name, false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(issuer_name, false);
|
|
||||||
if (!X509_NAME_get_text_by_NID(subject_name, NID_commonName, subject_cn, sizeof(subject_cn))) {
|
|
||||||
MS_LOG(WARNING) << "Get text by nid failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!X509_NAME_get_text_by_NID(issuer_name, NID_commonName, issuer_cn, sizeof(issuer_cn))) {
|
|
||||||
MS_LOG(WARNING) << "Get text by nid failed.";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
MS_LOG(INFO) << "the subject:" << subject_cn << ", the issuer:" << issuer_cn;
|
|
||||||
|
|
||||||
BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
|
X509_NAME *caSubjectX509CN = X509_get_subject_name(caCert);
|
||||||
MS_EXCEPTION_IF_NULL(ca_bio);
|
X509_NAME *subIssuerX509CN = X509_get_issuer_name(subCert);
|
||||||
X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
|
|
||||||
MS_EXCEPTION_IF_NULL(ca_cert);
|
int ret = X509_NAME_get_text_by_NID(caSubjectX509CN, NID_commonName, caSubjectCN, sizeof(caSubjectCN));
|
||||||
char ca_subject_cn[256] = "";
|
if (ret < 0) {
|
||||||
char ca_issuer_cn[256] = "";
|
|
||||||
X509_NAME *ca_subject_name = X509_get_subject_name(ca_cert);
|
|
||||||
X509_NAME *ca_issuer_name = X509_get_issuer_name(ca_cert);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(ca_subject_name, false);
|
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(ca_issuer_name, false);
|
|
||||||
if (!X509_NAME_get_text_by_NID(ca_subject_name, NID_commonName, ca_subject_cn, sizeof(subject_cn))) {
|
|
||||||
MS_LOG(WARNING) << "Get text by nid failed.";
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!X509_NAME_get_text_by_NID(ca_issuer_name, NID_commonName, ca_issuer_cn, sizeof(issuer_cn))) {
|
ret = X509_NAME_get_text_by_NID(subIssuerX509CN, NID_commonName, subIssuerCN, sizeof(subIssuerCN));
|
||||||
MS_LOG(WARNING) << "Get text by nid failed.";
|
if (ret < 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "the subject:" << ca_subject_cn << ", the issuer:" << ca_issuer_cn;
|
|
||||||
BIO_free_all(ca_bio);
|
std::string caSubjectCNStr = caSubjectCN;
|
||||||
if (strcmp(issuer_cn, ca_subject_cn) != 0) {
|
std::string subIssuerCNStr = subIssuerCN;
|
||||||
|
|
||||||
|
if (caSubjectCNStr != subIssuerCNStr) {
|
||||||
|
MS_LOG(EXCEPTION) << "root CA cert subject cn is not equal with equip CA cert issuer cn.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -330,19 +387,170 @@ bool CommUtil::VerifyCipherList(const std::vector<std::string> &list) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommUtil::InitOpenSSLEnv() {
|
bool CommUtil::verifyCertKeyID(const X509 *caCert, const X509 *subCert) {
|
||||||
if (!SSL_library_init()) {
|
MS_EXCEPTION_IF_NULL(caCert);
|
||||||
MS_LOG(EXCEPTION) << "SSL_library_init failed.";
|
MS_EXCEPTION_IF_NULL(subCert);
|
||||||
|
int crit = 0;
|
||||||
|
ASN1_OCTET_STRING *skid =
|
||||||
|
reinterpret_cast<ASN1_OCTET_STRING *>(X509_get_ext_d2i(caCert, NID_subject_key_identifier, &crit, NULL));
|
||||||
|
MS_EXCEPTION_IF_NULL(skid);
|
||||||
|
const int keyidLen = 512;
|
||||||
|
char subject_keyid[keyidLen] = {0};
|
||||||
|
for (int i = 0; i < skid->length; i++) {
|
||||||
|
char keyid[8] = {0};
|
||||||
|
int base = keyidLen;
|
||||||
|
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)skid->data[i]);
|
||||||
|
int ret = strcat_s(subject_keyid, base, keyid);
|
||||||
|
if (ret == -1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (!ERR_load_crypto_strings()) {
|
|
||||||
MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
|
AUTHORITY_KEYID *akeyid =
|
||||||
|
reinterpret_cast<AUTHORITY_KEYID *>(X509_get_ext_d2i(subCert, NID_authority_key_identifier, &crit, NULL));
|
||||||
|
MS_EXCEPTION_IF_NULL(akeyid);
|
||||||
|
MS_EXCEPTION_IF_NULL(akeyid->keyid);
|
||||||
|
char issuer_keyid[keyidLen] = {0};
|
||||||
|
for (int i = 0; i < akeyid->keyid->length; i++) {
|
||||||
|
char keyid[8] = {0};
|
||||||
|
int base = keyidLen;
|
||||||
|
(void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)(akeyid->keyid->data[i]));
|
||||||
|
int ret = strcat_s(issuer_keyid, base, keyid);
|
||||||
|
if (ret == -1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (!SSL_load_error_strings()) {
|
|
||||||
MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
|
std::string subject_keyid_str = subject_keyid;
|
||||||
|
std::string issuer_keyid_str = issuer_keyid;
|
||||||
|
if (subject_keyid_str != issuer_keyid_str) {
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
if (!OpenSSL_add_all_algorithms()) {
|
return true;
|
||||||
MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
|
}
|
||||||
|
|
||||||
|
bool CommUtil::verifySingature(const X509 *caCert, const X509 *subCert) {
|
||||||
|
MS_EXCEPTION_IF_NULL(caCert);
|
||||||
|
MS_EXCEPTION_IF_NULL(subCert);
|
||||||
|
EVP_PKEY *caCertPubKey = X509_get_pubkey(const_cast<X509 *>(caCert));
|
||||||
|
|
||||||
|
int ret = 0;
|
||||||
|
ret = X509_verify(const_cast<X509 *>(subCert), caCertPubKey);
|
||||||
|
if (ret != 1) {
|
||||||
|
EVP_PKEY_free(caCertPubKey);
|
||||||
|
MS_LOG(ERROR) << "sub cert verify is failed";
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << "verifyCAChain success.";
|
||||||
|
EVP_PKEY_free(caCertPubKey);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CommUtil::verifyExtendedAttributes(const X509 *caCert) {
|
||||||
|
MS_EXCEPTION_IF_NULL(caCert);
|
||||||
|
int cirt = 0;
|
||||||
|
BASIC_CONSTRAINTS *bcons =
|
||||||
|
reinterpret_cast<BASIC_CONSTRAINTS *>(X509_get_ext_d2i(caCert, NID_basic_constraints, &cirt, NULL));
|
||||||
|
if (bcons == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!bcons->ca) {
|
||||||
|
MS_LOG(ERROR) << "Subject Type is End Entity.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Subject Type is CA.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommUtil::verifyCertPipeline(const X509 *caCert, const X509 *subCert) {
|
||||||
|
if (!CommUtil::VerifyCommonName(caCert, subCert)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Verify common name failed.";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!CommUtil::verifySingature(caCert, subCert)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Verify Singature failed.";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!CommUtil::verifyExtendedAttributes(caCert)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Verify Extended Attributes failed.";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!CommUtil::verifyCertKeyID(caCert, subCert)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Verify Cert KeyID failed.";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!CommUtil::verifyCertTimeStamp(caCert) || !CommUtil::verifyCertTimeStamp(subCert)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Verify Cert Time failed.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CommUtil::checkCRLTime(const std::string &crlPath) {
|
||||||
|
if (!IsFileExists(crlPath)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
BIO *bio = BIO_new_file(crlPath.c_str(), "r");
|
||||||
|
if (bio == nullptr) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
bool result = true;
|
||||||
|
X509_CRL *crl = nullptr;
|
||||||
|
do {
|
||||||
|
crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
|
||||||
|
if (crl == nullptr) {
|
||||||
|
MS_LOG(WARNING) << "crl is nullptr. return true.";
|
||||||
|
result = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
const ASN1_TIME *lastUpdate = X509_CRL_get0_lastUpdate(crl);
|
||||||
|
const ASN1_TIME *nextUpdate = X509_CRL_get0_nextUpdate(crl);
|
||||||
|
|
||||||
|
int day = 0;
|
||||||
|
int sec = 0;
|
||||||
|
int ret = ASN1_TIME_diff(&day, &sec, lastUpdate, NULL);
|
||||||
|
if (ret != 1) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (day < 0 || sec < 0) {
|
||||||
|
MS_LOG(ERROR) << "crl start time is later than now time.";
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
day = 0;
|
||||||
|
sec = 0;
|
||||||
|
ret = ASN1_TIME_diff(&day, &sec, NULL, nextUpdate);
|
||||||
|
if (ret != 1) {
|
||||||
|
result = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (day < 0 || sec < 0) {
|
||||||
|
MS_LOG(WARNING) << "crl update time is sooner than now time. please update crl";
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "verifyCRL time success.";
|
||||||
|
} while (0);
|
||||||
|
|
||||||
|
X509_CRL_free(crl);
|
||||||
|
BIO_free_all(bio);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string CommUtil::BoolToString(bool alive) {
|
||||||
|
if (alive) {
|
||||||
|
return "True";
|
||||||
|
} else {
|
||||||
|
return "False";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CommUtil::StringToBool(const std::string &alive) {
|
||||||
|
if (alive == "True") {
|
||||||
|
return true;
|
||||||
|
} else if (alive == "False") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
|
|
|
@ -44,6 +44,7 @@
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <openssl/pkcs12.h>
|
#include <openssl/pkcs12.h>
|
||||||
#include <openssl/bio.h>
|
#include <openssl/bio.h>
|
||||||
|
#include <openssl/x509v3.h>
|
||||||
|
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
@ -99,8 +100,12 @@ class CommUtil {
|
||||||
static bool CheckIp(const std::string &ip);
|
static bool CheckIp(const std::string &ip);
|
||||||
static bool CheckPort(const uint16_t &port);
|
static bool CheckPort(const uint16_t &port);
|
||||||
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
|
static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip);
|
||||||
|
static std::string GetLoopBackInterfaceName();
|
||||||
static std::string GenerateUUID();
|
static std::string GenerateUUID();
|
||||||
static std::string NodeRoleToString(const NodeRole &role);
|
static std::string NodeRoleToString(const NodeRole &role);
|
||||||
|
static NodeRole StringToNodeRole(const std::string &roleStr);
|
||||||
|
static std::string BoolToString(bool alive);
|
||||||
|
static bool StringToBool(const std::string &alive);
|
||||||
static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
|
static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num,
|
||||||
const int32_t &total_server_num);
|
const int32_t &total_server_num);
|
||||||
static bool Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds);
|
static bool Retry(const std::function<bool()> &func, size_t max_attempts, size_t interval_milliseconds);
|
||||||
|
@ -112,19 +117,21 @@ class CommUtil {
|
||||||
static std::string ClusterStateToString(const ClusterState &state);
|
static std::string ClusterStateToString(const ClusterState &state);
|
||||||
|
|
||||||
// Parse the configuration file according to the key.
|
// Parse the configuration file according to the key.
|
||||||
static std::string ParseConfig(const Configuration &config, const std::string &data);
|
static std::string ParseConfig(const Configuration &config, const std::string &key);
|
||||||
|
|
||||||
// verify valid of certificate time
|
// verify valid of certificate time
|
||||||
static bool VerifyCertTime(const X509 *cert, int64_t time = 0);
|
static bool VerifyCertTime(const X509 *cert, int64_t time = 0);
|
||||||
|
static bool verifyCertTimeStamp(const X509 *cert);
|
||||||
// verify valid of equip certificate with CRL
|
// verify valid of equip certificate with CRL
|
||||||
static bool VerifyCRL(const X509 *cert, const std::string &crl_path);
|
static bool VerifyCRL(const X509 *cert, const std::string &crl_path);
|
||||||
// Check the common name of the certificate
|
static bool VerifyCommonName(const X509 *caCert, const X509 *subCert);
|
||||||
static bool VerifyCommonName(const X509 *cert, const std::string &ca_path);
|
|
||||||
// The string is divided according to delim
|
|
||||||
static std::vector<std::string> Split(const std::string &s, char delim);
|
static std::vector<std::string> Split(const std::string &s, char delim);
|
||||||
// Check the cipher list of the certificate
|
|
||||||
static bool VerifyCipherList(const std::vector<std::string> &list);
|
static bool VerifyCipherList(const std::vector<std::string> &list);
|
||||||
static void InitOpenSSLEnv();
|
static bool verifyCertKeyID(const X509 *caCert, const X509 *subCert);
|
||||||
|
static bool verifySingature(const X509 *caCert, const X509 *subCert);
|
||||||
|
static bool verifyExtendedAttributes(const X509 *caCert);
|
||||||
|
static void verifyCertPipeline(const X509 *caCert, const X509 *subCert);
|
||||||
|
static bool checkCRLTime(const std::string &crlPath);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static std::random_device rd;
|
static std::random_device rd;
|
||||||
|
|
|
@ -43,7 +43,7 @@ void CommunicatorBase::Join() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CommunicatorBase::running() const { return running_; }
|
bool CommunicatorBase::running() { return running_; }
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -19,10 +19,10 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
|
||||||
#include "utils/hash_map.h"
|
|
||||||
#include "ps/core/communicator/message_handler.h"
|
#include "ps/core/communicator/message_handler.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "ps/core/communicator/http_message_handler.h"
|
#include "ps/core/communicator/http_message_handler.h"
|
||||||
|
@ -57,6 +57,7 @@ enum class TcpUserCommand {
|
||||||
kQueryInstance,
|
kQueryInstance,
|
||||||
kEnableFLS,
|
kEnableFLS,
|
||||||
kDisableFLS,
|
kDisableFLS,
|
||||||
|
kSyncAfterRecover,
|
||||||
kExchangeKeys,
|
kExchangeKeys,
|
||||||
kGetKeys
|
kGetKeys
|
||||||
};
|
};
|
||||||
|
@ -84,10 +85,10 @@ class CommunicatorBase {
|
||||||
|
|
||||||
bool SendResponse(const void *rsp_data, size_t rsp_len, const std::shared_ptr<MessageHandler> &msg_handler);
|
bool SendResponse(const void *rsp_data, size_t rsp_len, const std::shared_ptr<MessageHandler> &msg_handler);
|
||||||
|
|
||||||
bool running() const;
|
bool running();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
mindspore::HashMap<std::string, MessageCallback> msg_callbacks_;
|
std::unordered_map<std::string, MessageCallback> msg_callbacks_;
|
||||||
std::thread running_thread_;
|
std::thread running_thread_;
|
||||||
bool running_;
|
bool running_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -22,11 +22,11 @@ namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
namespace core {
|
namespace core {
|
||||||
bool HttpCommunicator::Start() {
|
bool HttpCommunicator::Start() {
|
||||||
|
MS_EXCEPTION_IF_NULL(http_server_);
|
||||||
MS_LOG(INFO) << "Initialize http server IP:" << ip_ << ", PORT:" << port_;
|
MS_LOG(INFO) << "Initialize http server IP:" << ip_ << ", PORT:" << port_;
|
||||||
if (!http_server_->InitServer()) {
|
if (!http_server_->InitServer()) {
|
||||||
MS_LOG(EXCEPTION) << "The communicator init http server failed.";
|
MS_LOG(EXCEPTION) << "The communicator init http server failed.";
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(http_server_);
|
|
||||||
if (!http_server_->Start()) {
|
if (!http_server_->Start()) {
|
||||||
MS_LOG(EXCEPTION) << "Http server starting failed.";
|
MS_LOG(EXCEPTION) << "Http server starting failed.";
|
||||||
}
|
}
|
||||||
|
@ -51,9 +51,11 @@ bool HttpCommunicator::Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void HttpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) {
|
void HttpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) {
|
||||||
|
MS_LOG(INFO) << "msg_type is: " << msg_type;
|
||||||
msg_callbacks_[msg_type] = cb;
|
msg_callbacks_[msg_type] = cb;
|
||||||
http_msg_callbacks_[msg_type] = std::bind(
|
http_msg_callbacks_[msg_type] = std::bind(
|
||||||
[&](std::shared_ptr<HttpMessageHandler> http_msg) -> void {
|
[&](std::shared_ptr<HttpMessageHandler> http_msg) -> void {
|
||||||
|
MS_EXCEPTION_IF_NULL(http_msg);
|
||||||
std::shared_ptr<MessageHandler> http_msg_handler = std::make_shared<HttpMsgHandler>(http_msg);
|
std::shared_ptr<MessageHandler> http_msg_handler = std::make_shared<HttpMsgHandler>(http_msg);
|
||||||
MS_EXCEPTION_IF_NULL(http_msg_handler);
|
MS_EXCEPTION_IF_NULL(http_msg_handler);
|
||||||
msg_callbacks_[msg_type](http_msg_handler);
|
msg_callbacks_[msg_type](http_msg_handler);
|
||||||
|
@ -61,7 +63,8 @@ void HttpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const Me
|
||||||
},
|
},
|
||||||
std::placeholders::_1);
|
std::placeholders::_1);
|
||||||
|
|
||||||
std::string url = "/";
|
std::string url = ps::PSContext::instance()->http_url_prefix();
|
||||||
|
url += "/";
|
||||||
url += msg_type;
|
url += msg_type;
|
||||||
MS_EXCEPTION_IF_NULL(http_server_);
|
MS_EXCEPTION_IF_NULL(http_server_);
|
||||||
bool is_succeed = http_server_->RegisterRoute(url, &http_msg_callbacks_[msg_type]);
|
bool is_succeed = http_server_->RegisterRoute(url, &http_msg_callbacks_[msg_type]);
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "utils/hash_map.h"
|
#include <unordered_map>
|
||||||
#include "ps/core/communicator/http_server.h"
|
#include "ps/core/communicator/http_server.h"
|
||||||
#include "ps/core/communicator/http_message_handler.h"
|
#include "ps/core/communicator/http_message_handler.h"
|
||||||
#include "ps/core/communicator/task_executor.h"
|
#include "ps/core/communicator/task_executor.h"
|
||||||
|
@ -46,7 +46,7 @@ class HttpCommunicator : public CommunicatorBase {
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<TaskExecutor> task_executor_;
|
std::shared_ptr<TaskExecutor> task_executor_;
|
||||||
std::shared_ptr<HttpServer> http_server_;
|
std::shared_ptr<HttpServer> http_server_;
|
||||||
mindspore::HashMap<std::string, HttpMsgCallback> http_msg_callbacks_;
|
std::unordered_map<std::string, HttpMsgCallback> http_msg_callbacks_;
|
||||||
|
|
||||||
std::string ip_;
|
std::string ip_;
|
||||||
uint16_t port_;
|
uint16_t port_;
|
||||||
|
|
|
@ -265,7 +265,7 @@ void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, co
|
||||||
}
|
}
|
||||||
|
|
||||||
void HttpMessageHandler::ErrorResponse(int code, const RequestProcessResult &result) {
|
void HttpMessageHandler::ErrorResponse(int code, const RequestProcessResult &result) {
|
||||||
nlohmann::json error_json = {{"error_message", result.StatusMessage()}};
|
nlohmann::json error_json = {{"error_message", result.StatusMessage()}, {"code", kErrorCode}};
|
||||||
std::string out_error = error_json.dump();
|
std::string out_error = error_json.dump();
|
||||||
AddRespString(out_error);
|
AddRespString(out_error);
|
||||||
SetRespCode(code);
|
SetRespCode(code);
|
||||||
|
|
|
@ -26,7 +26,7 @@ HttpRequestHandler::~HttpRequestHandler() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HttpRequestHandler::Initialize(int fd, const mindspore::HashMap<std::string, OnRequestReceive *> &handlers) {
|
bool HttpRequestHandler::Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers) {
|
||||||
evbase_ = event_base_new();
|
evbase_ = event_base_new();
|
||||||
MS_EXCEPTION_IF_NULL(evbase_);
|
MS_EXCEPTION_IF_NULL(evbase_);
|
||||||
struct evhttp *http = evhttp_new(evbase_);
|
struct evhttp *http = evhttp_new(evbase_);
|
||||||
|
@ -115,8 +115,7 @@ bufferevent *HttpRequestHandler::BuffereventCallback(event_base *base, void *arg
|
||||||
SSL_CTX *ctx = reinterpret_cast<SSL_CTX *>(arg);
|
SSL_CTX *ctx = reinterpret_cast<SSL_CTX *>(arg);
|
||||||
SSL *ssl = SSL_new(ctx);
|
SSL *ssl = SSL_new(ctx);
|
||||||
MS_EXCEPTION_IF_NULL(ssl);
|
MS_EXCEPTION_IF_NULL(ssl);
|
||||||
bufferevent *bev =
|
bufferevent *bev = bufferevent_openssl_socket_new(base, -1, ssl, BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE);
|
||||||
bufferevent_openssl_socket_new(base, -1, ssl, BUFFEREVENT_SSL_ACCEPTING, static_cast<int>(BEV_OPT_CLOSE_ON_FREE));
|
|
||||||
MS_EXCEPTION_IF_NULL(bev);
|
MS_EXCEPTION_IF_NULL(bev);
|
||||||
return bev;
|
return bev;
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,8 +25,8 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "utils/hash_map.h"
|
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "ps/core/communicator/http_message_handler.h"
|
#include "ps/core/communicator/http_message_handler.h"
|
||||||
#include "ps/core/communicator/ssl_http.h"
|
#include "ps/core/communicator/ssl_http.h"
|
||||||
|
@ -46,7 +46,7 @@ class HttpRequestHandler {
|
||||||
HttpRequestHandler() : evbase_(nullptr) {}
|
HttpRequestHandler() : evbase_(nullptr) {}
|
||||||
virtual ~HttpRequestHandler();
|
virtual ~HttpRequestHandler();
|
||||||
|
|
||||||
bool Initialize(int fd, const mindspore::HashMap<std::string, OnRequestReceive *> &handlers);
|
bool Initialize(int fd, const std::unordered_map<std::string, OnRequestReceive *> &handlers);
|
||||||
void Run();
|
void Run();
|
||||||
bool Stop();
|
bool Stop();
|
||||||
static bufferevent *BuffereventCallback(event_base *base, void *arg);
|
static bufferevent *BuffereventCallback(event_base *base, void *arg);
|
||||||
|
|
|
@ -68,7 +68,7 @@ bool HttpServer::InitServer() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
fd_ = ::socket(static_cast<int>(AF_INET), static_cast<int>(SOCK_STREAM), 0);
|
fd_ = ::socket(AF_INET, SOCK_STREAM, 0);
|
||||||
if (fd_ < 0) {
|
if (fd_ < 0) {
|
||||||
MS_LOG(ERROR) << "Socker error!";
|
MS_LOG(ERROR) << "Socker error!";
|
||||||
return false;
|
return false;
|
||||||
|
@ -84,7 +84,8 @@ bool HttpServer::InitServer() {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct sockaddr_in addr;
|
struct sockaddr_in addr;
|
||||||
if (memset_s(&addr, sizeof(addr), 0, sizeof(addr)) != EOK) {
|
errno_t ret = memset_s(&addr, sizeof(addr), 0, sizeof(addr));
|
||||||
|
if (ret != EOK) {
|
||||||
MS_LOG(EXCEPTION) << "Memset failed.";
|
MS_LOG(EXCEPTION) << "Memset failed.";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,6 +133,7 @@ bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *functio
|
||||||
if (!function) {
|
if (!function) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
MS_LOG(INFO) << "request handler url is: " << url;
|
||||||
request_handlers_[url] = function;
|
request_handlers_[url] = function;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -17,6 +17,8 @@
|
||||||
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_
|
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_
|
||||||
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_
|
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_
|
||||||
|
|
||||||
|
#include "ps/core/communicator/http_message_handler.h"
|
||||||
|
|
||||||
#include <event2/buffer.h>
|
#include <event2/buffer.h>
|
||||||
#include <event2/event.h>
|
#include <event2/event.h>
|
||||||
#include <event2/http.h>
|
#include <event2/http.h>
|
||||||
|
@ -34,10 +36,9 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "utils/hash_map.h"
|
|
||||||
#include "ps/core/communicator/http_message_handler.h"
|
|
||||||
#include "ps/core/communicator/http_request_handler.h"
|
#include "ps/core/communicator/http_request_handler.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -77,7 +78,7 @@ class HttpServer {
|
||||||
std::vector<std::shared_ptr<std::thread>> worker_threads_;
|
std::vector<std::shared_ptr<std::thread>> worker_threads_;
|
||||||
std::vector<std::shared_ptr<HttpRequestHandler>> http_request_handlers;
|
std::vector<std::shared_ptr<HttpRequestHandler>> http_request_handlers;
|
||||||
int32_t backlog_;
|
int32_t backlog_;
|
||||||
mindspore::HashMap<std::string, OnRequestReceive *> request_handlers_;
|
std::unordered_map<std::string, OnRequestReceive *> request_handlers_;
|
||||||
int fd_;
|
int fd_;
|
||||||
};
|
};
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
|
@ -37,7 +38,18 @@ SSLClient::SSLClient() : ssl_ctx_(nullptr), check_time_thread_(nullptr), running
|
||||||
SSLClient::~SSLClient() { CleanSSL(); }
|
SSLClient::~SSLClient() { CleanSSL(); }
|
||||||
|
|
||||||
void SSLClient::InitSSL() {
|
void SSLClient::InitSSL() {
|
||||||
CommUtil::InitOpenSSLEnv();
|
if (!SSL_library_init()) {
|
||||||
|
MS_LOG(EXCEPTION) << "SSL_library_init failed.";
|
||||||
|
}
|
||||||
|
if (!ERR_load_crypto_strings()) {
|
||||||
|
MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
|
||||||
|
}
|
||||||
|
if (!SSL_load_error_strings()) {
|
||||||
|
MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
|
||||||
|
}
|
||||||
|
if (!OpenSSL_add_all_algorithms()) {
|
||||||
|
MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
|
||||||
|
}
|
||||||
ssl_ctx_ = SSL_CTX_new(SSLv23_client_method());
|
ssl_ctx_ = SSL_CTX_new(SSLv23_client_method());
|
||||||
if (!ssl_ctx_) {
|
if (!ssl_ctx_) {
|
||||||
MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
|
MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
|
||||||
|
@ -65,6 +77,7 @@ void SSLClient::InitSSL() {
|
||||||
EVP_PKEY *pkey = nullptr;
|
EVP_PKEY *pkey = nullptr;
|
||||||
X509 *cert = nullptr;
|
X509 *cert = nullptr;
|
||||||
STACK_OF(X509) *ca_stack = nullptr;
|
STACK_OF(X509) *ca_stack = nullptr;
|
||||||
|
MS_LOG(INFO) << "cliet cert: " << client_cert;
|
||||||
BIO *bio = BIO_new_file(client_cert.c_str(), "rb");
|
BIO *bio = BIO_new_file(client_cert.c_str(), "rb");
|
||||||
MS_EXCEPTION_IF_NULL(bio);
|
MS_EXCEPTION_IF_NULL(bio);
|
||||||
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
|
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
|
||||||
|
@ -83,27 +96,26 @@ void SSLClient::InitSSL() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. load ca cert.
|
// 3. load ca cert.
|
||||||
std::string client_ca = kCAcrt;
|
|
||||||
std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
|
std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
|
||||||
if (!CommUtil::IsFileExists(ca_path)) {
|
if (!CommUtil::IsFileExists(ca_path)) {
|
||||||
MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
|
MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
|
||||||
}
|
}
|
||||||
client_ca = ca_path;
|
BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
|
||||||
|
MS_EXCEPTION_IF_NULL(ca_bio);
|
||||||
|
X509 *caCert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
|
||||||
std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
|
std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
|
||||||
if (crl_path.empty()) {
|
if (crl_path.empty()) {
|
||||||
MS_LOG(INFO) << "The crl path is empty.";
|
MS_LOG(INFO) << "The crl path is empty.";
|
||||||
|
} else if (!CommUtil::checkCRLTime(crl_path)) {
|
||||||
|
MS_LOG(EXCEPTION) << "check crl time failed";
|
||||||
} else if (!CommUtil::VerifyCRL(cert, crl_path)) {
|
} else if (!CommUtil::VerifyCRL(cert, crl_path)) {
|
||||||
MS_LOG(EXCEPTION) << "Verify crl failed.";
|
MS_LOG(EXCEPTION) << "Verify crl failed.";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!CommUtil::VerifyCommonName(cert, client_ca)) {
|
CommUtil::verifyCertPipeline(caCert, cert);
|
||||||
MS_LOG(EXCEPTION) << "Verify common name failed.";
|
|
||||||
}
|
|
||||||
|
|
||||||
SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, 0);
|
SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, 0);
|
||||||
|
if (!SSL_CTX_load_verify_locations(ssl_ctx_, ca_path.c_str(), nullptr)) {
|
||||||
if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) {
|
|
||||||
MS_LOG(EXCEPTION) << "SSL load ca location failed!";
|
MS_LOG(EXCEPTION) << "SSL load ca location failed!";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -115,13 +127,21 @@ void SSLClient::InitSSL() {
|
||||||
if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
|
if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
|
||||||
MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
|
MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
|
||||||
}
|
}
|
||||||
|
InitSSLCtx(cert, pkey);
|
||||||
|
StartCheckCertTime(*config_, cert);
|
||||||
|
|
||||||
// 4. load client cert
|
EVP_PKEY_free(pkey);
|
||||||
if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) {
|
X509_free(caCert);
|
||||||
|
X509_free(cert);
|
||||||
|
(void)BIO_free(ca_bio);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SSLClient::InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey) {
|
||||||
|
if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast<X509 *>(cert))) {
|
||||||
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
|
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) {
|
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast<EVP_PKEY *>(pkey))) {
|
||||||
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
|
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,7 +158,7 @@ void SSLClient::InitSSL() {
|
||||||
MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
|
MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
|
||||||
}
|
}
|
||||||
|
|
||||||
StartCheckCertTime(*config_, cert);
|
SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SSLClient::CleanSSL() {
|
void SSLClient::CleanSSL() {
|
||||||
|
|
|
@ -37,6 +37,7 @@
|
||||||
#include "ps/core/comm_util.h"
|
#include "ps/core/comm_util.h"
|
||||||
#include "ps/constants.h"
|
#include "ps/constants.h"
|
||||||
#include "ps/core/file_configuration.h"
|
#include "ps/core/file_configuration.h"
|
||||||
|
#include "ps/ps_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
|
@ -60,6 +61,7 @@ class SSLClient {
|
||||||
|
|
||||||
void StartCheckCertTime(const Configuration &config, const X509 *cert);
|
void StartCheckCertTime(const Configuration &config, const X509 *cert);
|
||||||
void StopCheckCertTime();
|
void StopCheckCertTime();
|
||||||
|
void InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey);
|
||||||
|
|
||||||
SSL_CTX *ssl_ctx_;
|
SSL_CTX *ssl_ctx_;
|
||||||
std::unique_ptr<std::thread> check_time_thread_;
|
std::unique_ptr<std::thread> check_time_thread_;
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
|
@ -35,7 +36,18 @@ SSLHTTP::SSLHTTP() : ssl_ctx_(nullptr) { InitSSL(); }
|
||||||
SSLHTTP::~SSLHTTP() { CleanSSL(); }
|
SSLHTTP::~SSLHTTP() { CleanSSL(); }
|
||||||
|
|
||||||
void SSLHTTP::InitSSL() {
|
void SSLHTTP::InitSSL() {
|
||||||
CommUtil::InitOpenSSLEnv();
|
if (!SSL_library_init()) {
|
||||||
|
MS_LOG(EXCEPTION) << "SSL_library_init failed.";
|
||||||
|
}
|
||||||
|
if (!ERR_load_crypto_strings()) {
|
||||||
|
MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
|
||||||
|
}
|
||||||
|
if (!SSL_load_error_strings()) {
|
||||||
|
MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
|
||||||
|
}
|
||||||
|
if (!OpenSSL_add_all_algorithms()) {
|
||||||
|
MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
|
||||||
|
}
|
||||||
ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
|
ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
|
||||||
if (!ssl_ctx_) {
|
if (!ssl_ctx_) {
|
||||||
MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
|
MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
|
||||||
|
@ -79,25 +91,39 @@ void SSLHTTP::InitSSL() {
|
||||||
MS_LOG(EXCEPTION) << "PKCS12_parse failed.";
|
MS_LOG(EXCEPTION) << "PKCS12_parse failed.";
|
||||||
}
|
}
|
||||||
PKCS12_free(p12);
|
PKCS12_free(p12);
|
||||||
|
if (cert == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "the cert is nullptr";
|
||||||
|
}
|
||||||
|
if (pkey == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "the key is nullptr";
|
||||||
|
}
|
||||||
|
if (!CommUtil::verifyCertTimeStamp(cert)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Verify Cert Time failed.";
|
||||||
|
}
|
||||||
std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList);
|
std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList);
|
||||||
|
InitSSLCtx(cert, pkey, default_cipher_list);
|
||||||
|
EVP_PKEY_free(pkey);
|
||||||
|
X509_free(cert);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SSLHTTP::InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey, const std::string &default_cipher_list) {
|
||||||
if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
|
if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
|
||||||
MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
|
MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
|
||||||
}
|
}
|
||||||
|
if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast<X509 *>(cert))) {
|
||||||
if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) {
|
|
||||||
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
|
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
|
||||||
}
|
}
|
||||||
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) {
|
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast<EVP_PKEY *>(pkey))) {
|
||||||
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
|
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
|
||||||
}
|
}
|
||||||
if (!SSL_CTX_check_private_key(ssl_ctx_)) {
|
if (!SSL_CTX_check_private_key(ssl_ctx_)) {
|
||||||
MS_LOG(EXCEPTION) << "SSL check private key file failed!";
|
MS_LOG(EXCEPTION) << "SSL check private key file failed!";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
|
if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 |
|
||||||
SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
|
SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
|
||||||
MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
|
MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
|
||||||
}
|
}
|
||||||
|
SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SSLHTTP::CleanSSL() {
|
void SSLHTTP::CleanSSL() {
|
||||||
|
|
|
@ -53,7 +53,7 @@ class SSLHTTP {
|
||||||
|
|
||||||
void InitSSL();
|
void InitSSL();
|
||||||
void CleanSSL();
|
void CleanSSL();
|
||||||
|
void InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey, const std::string &default_cipher_list);
|
||||||
SSL_CTX *ssl_ctx_;
|
SSL_CTX *ssl_ctx_;
|
||||||
};
|
};
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
|
@ -43,7 +44,18 @@ SSLWrapper::SSLWrapper()
|
||||||
SSLWrapper::~SSLWrapper() { CleanSSL(); }
|
SSLWrapper::~SSLWrapper() { CleanSSL(); }
|
||||||
|
|
||||||
void SSLWrapper::InitSSL() {
|
void SSLWrapper::InitSSL() {
|
||||||
CommUtil::InitOpenSSLEnv();
|
if (!SSL_library_init()) {
|
||||||
|
MS_LOG(EXCEPTION) << "SSL_library_init failed.";
|
||||||
|
}
|
||||||
|
if (!ERR_load_crypto_strings()) {
|
||||||
|
MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
|
||||||
|
}
|
||||||
|
if (!SSL_load_error_strings()) {
|
||||||
|
MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
|
||||||
|
}
|
||||||
|
if (!OpenSSL_add_all_algorithms()) {
|
||||||
|
MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
|
||||||
|
}
|
||||||
ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
|
ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
|
||||||
if (!ssl_ctx_) {
|
if (!ssl_ctx_) {
|
||||||
MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
|
MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
|
||||||
|
@ -100,31 +112,42 @@ void SSLWrapper::InitSSL() {
|
||||||
std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
|
std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath);
|
||||||
if (crl_path.empty()) {
|
if (crl_path.empty()) {
|
||||||
MS_LOG(INFO) << "The crl path is empty.";
|
MS_LOG(INFO) << "The crl path is empty.";
|
||||||
|
} else if (!CommUtil::checkCRLTime(crl_path)) {
|
||||||
|
MS_LOG(EXCEPTION) << "check crl time failed";
|
||||||
} else if (!CommUtil::VerifyCRL(cert, crl_path)) {
|
} else if (!CommUtil::VerifyCRL(cert, crl_path)) {
|
||||||
MS_LOG(EXCEPTION) << "Verify crl failed.";
|
MS_LOG(EXCEPTION) << "Verify crl failed.";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string client_ca = kCAcrt;
|
|
||||||
std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
|
std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
|
||||||
if (!CommUtil::IsFileExists(ca_path)) {
|
if (!CommUtil::IsFileExists(ca_path)) {
|
||||||
MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
|
MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
|
||||||
}
|
}
|
||||||
client_ca = ca_path;
|
BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
|
||||||
|
MS_EXCEPTION_IF_NULL(ca_bio);
|
||||||
|
X509 *caCert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
|
||||||
|
|
||||||
if (!CommUtil::VerifyCommonName(cert, client_ca)) {
|
CommUtil::verifyCertPipeline(caCert, cert);
|
||||||
MS_LOG(EXCEPTION) << "Verify common name failed.";
|
|
||||||
}
|
|
||||||
|
|
||||||
SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER, 0);
|
SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER, 0);
|
||||||
if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) {
|
if (!SSL_CTX_load_verify_locations(ssl_ctx_, ca_path.c_str(), nullptr)) {
|
||||||
MS_LOG(EXCEPTION) << "SSL load ca location failed!";
|
MS_LOG(EXCEPTION) << "SSL load ca location failed!";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) {
|
InitSSLCtx(cert, pkey);
|
||||||
|
StartCheckCertTime(*config_, cert, ca_path);
|
||||||
|
|
||||||
|
EVP_PKEY_free(pkey);
|
||||||
|
X509_free(caCert);
|
||||||
|
X509_free(cert);
|
||||||
|
(void)BIO_free(ca_bio);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SSLWrapper::InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey) {
|
||||||
|
if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast<X509 *>(cert))) {
|
||||||
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
|
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) {
|
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast<EVP_PKEY *>(pkey))) {
|
||||||
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
|
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,12 +158,10 @@ void SSLWrapper::InitSSL() {
|
||||||
SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
|
SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) {
|
||||||
MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
|
MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed.";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) {
|
if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) {
|
||||||
MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
|
MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!";
|
||||||
}
|
}
|
||||||
|
SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel);
|
||||||
StartCheckCertTime(*config_, cert, client_ca);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SSLWrapper::CleanSSL() {
|
void SSLWrapper::CleanSSL() {
|
||||||
|
|
|
@ -60,6 +60,7 @@ class SSLWrapper {
|
||||||
time_t ConvertAsn1Time(const ASN1_TIME *const time) const;
|
time_t ConvertAsn1Time(const ASN1_TIME *const time) const;
|
||||||
void StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path);
|
void StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path);
|
||||||
void StopCheckCertTime();
|
void StopCheckCertTime();
|
||||||
|
void InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey);
|
||||||
|
|
||||||
SSL_CTX *ssl_ctx_;
|
SSL_CTX *ssl_ctx_;
|
||||||
|
|
||||||
|
|
|
@ -94,7 +94,6 @@ void TcpClient::Init() {
|
||||||
if (event_base_ == nullptr) {
|
if (event_base_ == nullptr) {
|
||||||
event_base_ = event_base_new();
|
event_base_ = event_base_new();
|
||||||
MS_EXCEPTION_IF_NULL(event_base_);
|
MS_EXCEPTION_IF_NULL(event_base_);
|
||||||
is_stop_ = false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sockaddr_in sin{};
|
sockaddr_in sin{};
|
||||||
|
@ -160,7 +159,7 @@ void TcpClient::Stop() {
|
||||||
|
|
||||||
void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
|
void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) {
|
||||||
const int one = 1;
|
const int one = 1;
|
||||||
int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
|
int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
MS_LOG(EXCEPTION) << "Set socket no delay failed!";
|
MS_LOG(EXCEPTION) << "Set socket no delay failed!";
|
||||||
}
|
}
|
||||||
|
@ -178,10 +177,10 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) {
|
||||||
auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
|
auto tcp_client = reinterpret_cast<TcpClient *>(ctx);
|
||||||
|
|
||||||
char read_buffer[kMessageChunkLength];
|
char read_buffer[kMessageChunkLength];
|
||||||
size_t read = 0;
|
int read = 0;
|
||||||
|
|
||||||
while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) {
|
while ((read = bufferevent_read(bev, &read_buffer, SizeToInt(sizeof(read_buffer)))) > 0) {
|
||||||
tcp_client->OnReadHandler(read_buffer, read);
|
tcp_client->OnReadHandler(read_buffer, IntToSize(read));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,6 +251,8 @@ void TcpClient::Start() {
|
||||||
event_base_mutex_.unlock();
|
event_base_mutex_.unlock();
|
||||||
MS_EXCEPTION_IF_NULL(event_base_);
|
MS_EXCEPTION_IF_NULL(event_base_);
|
||||||
int ret = event_base_dispatch(event_base_);
|
int ret = event_base_dispatch(event_base_);
|
||||||
|
// is_started_ should be false when finish dispatch
|
||||||
|
is_started_ = false;
|
||||||
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
|
MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!";
|
||||||
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
|
MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType)
|
||||||
<< "Event base dispatch failed with no events pending or active!";
|
<< "Event base dispatch failed with no events pending or active!";
|
||||||
|
|
|
@ -90,6 +90,7 @@ bool TcpCommunicator::Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TcpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) {
|
void TcpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) {
|
||||||
|
MS_LOG(INFO) << "msg_type is: " << msg_type;
|
||||||
msg_callbacks_.try_emplace(msg_type, cb);
|
msg_callbacks_.try_emplace(msg_type, cb);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include "utils/hash_map.h"
|
#include <unordered_map>
|
||||||
#include "proto/ps.pb.h"
|
#include "proto/ps.pb.h"
|
||||||
#include "ps/core/server_node.h"
|
#include "ps/core/server_node.h"
|
||||||
#include "ps/core/cluster_metadata.h"
|
#include "ps/core/cluster_metadata.h"
|
||||||
|
@ -36,7 +36,7 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
namespace core {
|
namespace core {
|
||||||
const mindspore::HashMap<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
||||||
{TcpUserCommand::kPush, "push"},
|
{TcpUserCommand::kPush, "push"},
|
||||||
{TcpUserCommand::kPull, "pull"},
|
{TcpUserCommand::kPull, "pull"},
|
||||||
{TcpUserCommand::kCount, "count"},
|
{TcpUserCommand::kCount, "count"},
|
||||||
|
@ -61,7 +61,8 @@ const mindspore::HashMap<TcpUserCommand, std::string> kUserCommandToMsgType = {
|
||||||
{TcpUserCommand::kNewInstance, "newInstance"},
|
{TcpUserCommand::kNewInstance, "newInstance"},
|
||||||
{TcpUserCommand::kQueryInstance, "queryInstance"},
|
{TcpUserCommand::kQueryInstance, "queryInstance"},
|
||||||
{TcpUserCommand::kEnableFLS, "enableFLS"},
|
{TcpUserCommand::kEnableFLS, "enableFLS"},
|
||||||
{TcpUserCommand::kDisableFLS, "disableFLS"}};
|
{TcpUserCommand::kDisableFLS, "disableFLS"},
|
||||||
|
{TcpUserCommand::kSyncAfterRecover, "syncAfterRecover"}};
|
||||||
|
|
||||||
class TcpCommunicator : public CommunicatorBase {
|
class TcpCommunicator : public CommunicatorBase {
|
||||||
public:
|
public:
|
||||||
|
@ -92,8 +93,9 @@ class TcpCommunicator : public CommunicatorBase {
|
||||||
MS_ERROR_IF_NULL_W_RET_VAL(msg, false);
|
MS_ERROR_IF_NULL_W_RET_VAL(msg, false);
|
||||||
size_t dest_size = msg_str.size();
|
size_t dest_size = msg_str.size();
|
||||||
size_t src_size = msg_str.size();
|
size_t src_size = msg_str.size();
|
||||||
if (memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size) != EOK) {
|
auto ret = memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size);
|
||||||
MS_LOG(EXCEPTION) << "Memcpy_s error";
|
if (ret != EOK) {
|
||||||
|
MS_LOG(EXCEPTION) << "memcpy_s error, error no " << ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (output != nullptr) {
|
if (output != nullptr) {
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ps {
|
namespace ps {
|
||||||
namespace core {
|
namespace core {
|
||||||
TcpMsgHandler::TcpMsgHandler(AbstractNode *const abstract_node, const std::shared_ptr<core::TcpConnection> &conn,
|
TcpMsgHandler::TcpMsgHandler(AbstractNode *abstract_node, const std::shared_ptr<core::TcpConnection> &conn,
|
||||||
const std::shared_ptr<MessageMeta> &meta, const DataPtr &data, size_t size)
|
const std::shared_ptr<MessageMeta> &meta, DataPtr data, size_t size)
|
||||||
: abstract_node_(abstract_node), tcp_conn_(conn), meta_(meta), data_ptr_(data), data_(nullptr), len_(size) {
|
: abstract_node_(abstract_node), tcp_conn_(conn), meta_(meta), data_ptr_(data), data_(nullptr), len_(size) {
|
||||||
if (data_ptr_ != nullptr) {
|
if (data_ptr_ != nullptr) {
|
||||||
data_ = data_ptr_.get();
|
data_ = data_ptr_.get();
|
||||||
|
|
|
@ -28,8 +28,8 @@ namespace ps {
|
||||||
namespace core {
|
namespace core {
|
||||||
class TcpMsgHandler : public MessageHandler {
|
class TcpMsgHandler : public MessageHandler {
|
||||||
public:
|
public:
|
||||||
TcpMsgHandler(AbstractNode *const abstract_node, const std::shared_ptr<core::TcpConnection> &conn,
|
TcpMsgHandler(AbstractNode *abstract_node, const std::shared_ptr<core::TcpConnection> &conn,
|
||||||
const std::shared_ptr<MessageMeta> &meta, const DataPtr &data, size_t size);
|
const std::shared_ptr<MessageMeta> &meta, DataPtr data, size_t size);
|
||||||
~TcpMsgHandler() override = default;
|
~TcpMsgHandler() override = default;
|
||||||
|
|
||||||
void *data() const override;
|
void *data() const override;
|
||||||
|
|
|
@ -175,8 +175,9 @@ void TcpServer::Init() {
|
||||||
listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
|
listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast<void *>(this),
|
||||||
LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
|
LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1,
|
||||||
reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
|
reinterpret_cast<struct sockaddr *>(&sin), sizeof(sin));
|
||||||
|
if (listener_ == nullptr) {
|
||||||
MS_EXCEPTION_IF_NULL(listener_);
|
MS_LOG(EXCEPTION) << "bind ip & port failed. please check.";
|
||||||
|
}
|
||||||
|
|
||||||
if (server_port_ == 0) {
|
if (server_port_ == 0) {
|
||||||
struct sockaddr_in sin_bound {};
|
struct sockaddr_in sin_bound {};
|
||||||
|
@ -306,6 +307,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
|
||||||
});
|
});
|
||||||
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
|
bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback,
|
||||||
reinterpret_cast<void *>(conn.get()));
|
reinterpret_cast<void *>(conn.get()));
|
||||||
|
MS_LOG(INFO) << "A client is connected, fd is " << fd;
|
||||||
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
|
if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) {
|
||||||
MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
|
MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!";
|
||||||
}
|
}
|
||||||
|
@ -411,7 +413,7 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) {
|
||||||
|
|
||||||
void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) {
|
void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) {
|
||||||
const int one = 1;
|
const int one = 1;
|
||||||
int ret = setsockopt(fd, static_cast<int>(IPPROTO_TCP), static_cast<int>(TCP_NODELAY), &one, sizeof(int));
|
int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int));
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
MS_LOG(EXCEPTION) << "Set socket no delay failed!";
|
MS_LOG(EXCEPTION) << "Set socket no delay failed!";
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,9 +25,10 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "utils/hash_map.h"
|
|
||||||
#include "ps/constants.h"
|
#include "ps/constants.h"
|
||||||
|
#include "nlohmann/json.hpp"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -51,6 +52,9 @@ class Configuration {
|
||||||
// Get configuration data from database or config file.
|
// Get configuration data from database or config file.
|
||||||
virtual std::string GetString(const std::string &key, const std::string &defaultvalue) const = 0;
|
virtual std::string GetString(const std::string &key, const std::string &defaultvalue) const = 0;
|
||||||
|
|
||||||
|
// Get configuration vector data from database or config file.
|
||||||
|
virtual std::vector<nlohmann::json> GetVector(const std::string &key) const = 0;
|
||||||
|
|
||||||
// Get configuration data from database or config file.
|
// Get configuration data from database or config file.
|
||||||
virtual int64_t GetInt(const std::string &key, int64_t default_value) const = 0;
|
virtual int64_t GetInt(const std::string &key, int64_t default_value) const = 0;
|
||||||
|
|
||||||
|
@ -59,6 +63,12 @@ class Configuration {
|
||||||
|
|
||||||
// Determine whether the configuration item exists.
|
// Determine whether the configuration item exists.
|
||||||
virtual bool Exists(const std::string &key) const = 0;
|
virtual bool Exists(const std::string &key) const = 0;
|
||||||
|
|
||||||
|
// storage meta data
|
||||||
|
virtual void PersistFile(const core::ClusterConfig &clusterConfig) const = 0;
|
||||||
|
|
||||||
|
// storage meta data without nodes
|
||||||
|
virtual void PersistNodes(const core::ClusterConfig &clusterConfig) const = 0;
|
||||||
};
|
};
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
|
|
|
@ -50,6 +50,15 @@ std::string FileConfiguration::Get(const std::string &key, const std::string &de
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<nlohmann::json> FileConfiguration::GetVector(const std::string &key) const {
|
||||||
|
if (!js.contains(key)) {
|
||||||
|
MS_LOG(WARNING) << "The key:" << key << " is not exist.";
|
||||||
|
return std::vector<nlohmann::json>();
|
||||||
|
}
|
||||||
|
|
||||||
|
return js.at(key);
|
||||||
|
}
|
||||||
|
|
||||||
std::string FileConfiguration::GetString(const std::string &key, const std::string &defaultvalue) const {
|
std::string FileConfiguration::GetString(const std::string &key, const std::string &defaultvalue) const {
|
||||||
if (!js.contains(key)) {
|
if (!js.contains(key)) {
|
||||||
MS_LOG(WARNING) << "The key:" << key << " is not exist.";
|
MS_LOG(WARNING) << "The key:" << key << " is not exist.";
|
||||||
|
@ -68,13 +77,7 @@ int64_t FileConfiguration::GetInt(const std::string &key, int64_t default_value)
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FileConfiguration::Put(const std::string &key, const std::string &value) {
|
void FileConfiguration::Put(const std::string &key, const std::string &value) { js[key] = value; }
|
||||||
std::ofstream output_file(file_path_);
|
|
||||||
js[key] = value;
|
|
||||||
output_file << js.dump();
|
|
||||||
|
|
||||||
output_file.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool FileConfiguration::Exists(const std::string &key) const {
|
bool FileConfiguration::Exists(const std::string &key) const {
|
||||||
if (!js.contains(key)) {
|
if (!js.contains(key)) {
|
||||||
|
@ -82,6 +85,53 @@ bool FileConfiguration::Exists(const std::string &key) const {
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FileConfiguration::PersistNodes(const core::ClusterConfig &clusterConfig) const {
|
||||||
|
if (!CommUtil::IsFileExists(file_path_)) {
|
||||||
|
MS_LOG(WARNING) << "The file path:" << file_path_ << " is not exist. create one";
|
||||||
|
}
|
||||||
|
|
||||||
|
nlohmann::json persist_js;
|
||||||
|
persist_js[kRecoveryTotalNodeNum] = clusterConfig.initial_total_node_num;
|
||||||
|
persist_js[kRecoveryNextWorkerRankId] = clusterConfig.initial_next_worker_rank_id;
|
||||||
|
persist_js[kRecoveryNextServerRankId] = clusterConfig.initial_next_server_rank_id;
|
||||||
|
|
||||||
|
auto node_infos = clusterConfig.initial_registered_nodes_infos;
|
||||||
|
for (const auto kvs : node_infos) {
|
||||||
|
std::unordered_map<std::string, std::string> res;
|
||||||
|
res["ip"] = kvs.second.ip_;
|
||||||
|
res["port"] = std::to_string(kvs.second.port_);
|
||||||
|
res["node_id"] = kvs.second.node_id_;
|
||||||
|
res["rank_id"] = std::to_string(kvs.second.rank_id_);
|
||||||
|
res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_);
|
||||||
|
res["alive"] = CommUtil::BoolToString(kvs.second.is_alive);
|
||||||
|
persist_js["node_ids"].push_back(res);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::ofstream output_file(file_path_);
|
||||||
|
output_file << persist_js.dump();
|
||||||
|
|
||||||
|
output_file.close();
|
||||||
|
MS_LOG(INFO) << "The nodes meta data persist to " << file_path_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void FileConfiguration::PersistFile(const core::ClusterConfig &clusterConfig) const {
|
||||||
|
if (!CommUtil::IsFileExists(file_path_)) {
|
||||||
|
MS_LOG(WARNING) << "The file path:" << file_path_ << " is not exist. create one";
|
||||||
|
}
|
||||||
|
|
||||||
|
nlohmann::json persist_js;
|
||||||
|
persist_js[kRecoveryWorkerNum] = clusterConfig.initial_worker_num;
|
||||||
|
persist_js[kRecoveryServerNum] = clusterConfig.initial_server_num;
|
||||||
|
persist_js[kRecoverySchedulerIp] = clusterConfig.scheduler_host;
|
||||||
|
persist_js[kRecoverySchedulerPort] = clusterConfig.scheduler_port;
|
||||||
|
|
||||||
|
std::ofstream output_file(file_path_);
|
||||||
|
output_file << persist_js.dump();
|
||||||
|
|
||||||
|
output_file.close();
|
||||||
|
MS_LOG(INFO) << "The meta data persist to " << file_path_;
|
||||||
|
}
|
||||||
} // namespace core
|
} // namespace core
|
||||||
} // namespace ps
|
} // namespace ps
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue