diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h index ce95d308b0d..3109bb679a5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_pull_weight_kernel.h @@ -30,7 +30,7 @@ namespace mindspore { namespace kernel { -// The duration between two PullWeights requests when return code is ResponseCode_SucNotReady. +// The duration between two PullWeight requests when return code is ResponseCode_SucNotReady. constexpr int kRetryDurationOfPullWeights = 200; template class FusedPullWeightKernel : public CPUKernel { @@ -52,11 +52,11 @@ class FusedPullWeightKernel : public CPUKernel { total_iteration_++; uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration(); if (step_num_per_iteration == 0) { - MS_LOG(EXCEPTION) << "Step numbers of per iteration should not be equal to 0"; + MS_LOG(EXCEPTION) << "step number per iteration should not be 0"; } - // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server. MS_LOG(INFO) << "Try to pull weights. Local step number: " << total_iteration_ << ", step number needs to run per iteration: " << step_num_per_iteration; + // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server. if (step_num_per_iteration != fl::kOneStepPerIteration && total_iteration_ % step_num_per_iteration != fl::kTrainBeginStepNum) { return true; @@ -86,6 +86,7 @@ class FusedPullWeightKernel : public CPUKernel { MS_EXCEPTION_IF_NULL(pull_weight_rsp_msg); pull_weight_rsp = flatbuffers::GetRoot(pull_weight_rsp_msg->data()); + MS_EXCEPTION_IF_NULL(pull_weight_rsp); retcode = pull_weight_rsp->retcode(); if (retcode == schema::ResponseCode_SucNotReady) { std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPullWeights)); @@ -95,12 +96,12 @@ class FusedPullWeightKernel : public CPUKernel { // Recreate fbb to avoid memory leak of FlatBuffers. fbb = std::make_shared(); if (!BuildPullWeightReq(fbb)) { - MS_LOG(EXCEPTION) << "Building request for FusedDownloadWeightsByKeys failed."; + MS_LOG(EXCEPTION) << "Building request for FusedPullWeight failed."; } continue; } else if (retcode != schema::ResponseCode_SUCCEED) { - MS_LOG(EXCEPTION) << "FusedPullWeight failed. Server return code: " << pull_weight_rsp->retcode() - << ", reason: " << pull_weight_rsp->reason()->str(); + MS_LOG(WARNING) << "FusedPullWeight failed. Server return code: " << pull_weight_rsp->retcode() + << ", reason: " << pull_weight_rsp->reason()->str(); } else { MS_LOG(DEBUG) << "FusedPullWeight succeed."; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h index 288af63fb9d..b08c1843c58 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/fused_push_weight_kernel.h @@ -28,7 +28,7 @@ namespace mindspore { namespace kernel { -// The duration between two PushWeights requests when return code is ResponseCode_SucNotReady. +// The duration between two PushWeight requests when return code is ResponseCode_SucNotReady. constexpr int kRetryDurationOfPushWeights = 200; template class FusedPushWeightKernel : public CPUKernel { @@ -50,11 +50,11 @@ class FusedPushWeightKernel : public CPUKernel { total_iteration_++; uint64_t step_num_per_iteration = fl::worker::FLWorker::GetInstance().worker_step_num_per_iteration(); if (step_num_per_iteration == 0) { - MS_LOG(EXCEPTION) << "Step numbers of per iteration should not be equal to 0"; + MS_LOG(EXCEPTION) << "step number per iterationb should not be 0"; } - // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server. MS_LOG(INFO) << "Try to push weights. Local step number: " << total_iteration_ << ", step number needs to run per iteration: " << step_num_per_iteration; + // The worker has to train kWorkerTrainStepNum standalone iterations before it communicates with server. if (step_num_per_iteration != fl::kOneStepPerIteration && total_iteration_ % step_num_per_iteration != fl::kTrainEndStepNum) { return true; @@ -87,6 +87,7 @@ class FusedPushWeightKernel : public CPUKernel { MS_EXCEPTION_IF_NULL(push_weight_rsp_msg); push_weight_rsp = flatbuffers::GetRoot(push_weight_rsp_msg->data()); + MS_EXCEPTION_IF_NULL(push_weight_rsp); retcode = push_weight_rsp->retcode(); if (retcode == schema::ResponseCode_SucNotReady) { std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDurationOfPushWeights)); @@ -98,8 +99,8 @@ class FusedPushWeightKernel : public CPUKernel { } continue; } else if (retcode != schema::ResponseCode_SUCCEED) { - MS_LOG(EXCEPTION) << "FusedPushWeight failed. Server return code: " << push_weight_rsp->retcode() - << ", reason: " << push_weight_rsp->reason()->str(); + MS_LOG(WARNING) << "FusedPushWeight failed. Server return code: " << push_weight_rsp->retcode() + << ", reason: " << push_weight_rsp->reason()->str(); } else { MS_LOG(DEBUG) << "FusedPushWeight succeed."; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/push_metrics_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/push_metrics_kernel.h index 3bba3ecc169..c5095e79138 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/push_metrics_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/fl/push_metrics_kernel.h @@ -87,7 +87,7 @@ class PushMetricsKernel : public CPUKernel { case schema::ResponseCode_OutOfTime: break; default: - MS_LOG(EXCEPTION) << "Launching push metrics for worker failed."; + MS_LOG(WARNING) << "Launching push metrics for worker failed."; } MS_LOG(INFO) << "Push metrics for loss and accuracy success."; diff --git a/mindspore/ccsrc/fl/CMakeLists.txt b/mindspore/ccsrc/fl/CMakeLists.txt index fd40a754baa..6b0114f381b 100644 --- a/mindspore/ccsrc/fl/CMakeLists.txt +++ b/mindspore/ccsrc/fl/CMakeLists.txt @@ -5,7 +5,6 @@ if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/aggregation_kernel_factory.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/fed_avg_kernel.cc") - list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/sgd_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/optimizer_kernel_factory.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel_factory.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/round_kernel.cc") @@ -20,6 +19,8 @@ if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_secrets_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/reconstruct_secrets_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/share_secrets_kernel.cc") + list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_list_sign_kernel.cc") + list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/get_list_sign_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/round/push_metrics_kernel.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/kernel/params_info.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/consistent_hash_ring.cc") @@ -35,6 +36,8 @@ if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM _FL_SRC_FILES "server/model_store.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/round.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/server.cc") + list(REMOVE_ITEM _FL_SRC_FILES "server/cert_verify.cc") + list(REMOVE_ITEM _FL_SRC_FILES "server/server_recovery.cc") list(REMOVE_ITEM _FL_SRC_FILES "server/iteration_metrics.cc") list(REMOVE_ITEM _FL_SRC_FILES "worker/fl_worker.cc") list(REMOVE_ITEM _FL_SRC_FILES "armour/secure_protocol/encrypt.cc") @@ -49,10 +52,6 @@ if(NOT ENABLE_CPU OR WIN32) list(REMOVE_ITEM _FL_SRC_FILES "armour/cipher/cipher_unmask.cc") endif() -if(CMAKE_SYSTEM_NAME MATCHES "Darwin") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-tautological-pointer-compare") -endif() - list(LENGTH _FL_SRC_FILES fl_file_num) if(NOT fl_file_num EQUAL 0) set_property(SOURCE ${_FL_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_FL) diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc index aedac9b9900..6cf41a3c35a 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_init.cc @@ -24,8 +24,8 @@ namespace mindspore { namespace armour { bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt, size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt, - size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt, - size_t cipher_reconstruct_secrets_up_cnt) { + size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt, + size_t cipher_get_list_sign_cnt, size_t cipher_reconstruct_secrets_up_cnt) { MS_LOG(INFO) << "CipherInit::Init START"; if (publicparam_.p == nullptr || param.p == nullptr || param.prime == nullptr || publicparam_.prime == nullptr) { MS_LOG(ERROR) << "CipherInit::input data invalid."; @@ -47,6 +47,8 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size get_secrets_threshold = cipher_get_secrets_cnt; client_list_threshold = cipher_get_clientlist_cnt; reconstruct_secrets_threshold = cipher_reconstruct_secrets_up_cnt; + push_list_sign_threshold = cipher_push_list_sign_cnt; + get_list_sign_threshold = cipher_get_list_sign_cnt; time_out_mutex_ = time_out_mutex; publicparam_.dp_eps = param.dp_eps; @@ -74,6 +76,8 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size MS_LOG(INFO) << " CipherInit get_secrets_threshold : " << get_secrets_threshold; MS_LOG(INFO) << " CipherInit client_list_threshold : " << client_list_threshold; MS_LOG(INFO) << " CipherInit reconstruct_secrets_threshold : " << reconstruct_secrets_threshold; + MS_LOG(INFO) << " CipherInit push_list_sign_threshold : " << push_list_sign_threshold; + MS_LOG(INFO) << " CipherInit get_list_sign_threshold : " << get_list_sign_threshold; MS_LOG(INFO) << " CipherInit featuremap_ : " << featuremap_; if (!Check_Parames()) { MS_LOG(ERROR) << "Cipher parameters are illegal."; @@ -81,7 +85,6 @@ bool CipherInit::Init(const CipherPublicPara ¶m, size_t time_out_mutex, size } MS_LOG(INFO) << " CipherInit::Init Success"; } - if (param.encrypt_type == mindspore::ps::kStablePWEncryptType) { cipher_meta_storage_.RegisterStablePWClass(); MS_LOG(INFO) << "Register metadata for StablePWEncrypt is finished."; @@ -96,9 +99,10 @@ bool CipherInit::Check_Parames() { } if (share_secrets_threshold < reconstruct_secrets_threshold) { - MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger " - "than share_secrets_threshold, but got they are:" - << reconstruct_secrets_threshold << ", " << share_secrets_threshold; + MS_LOG(ERROR) << "reconstruct_secrets_threshold should not be larger than " + "share_secrets_threshold." + << "reconstruct_secrets_threshold: " << reconstruct_secrets_threshold + << ", share_secrets_threshold: " << share_secrets_threshold; return false; } diff --git a/mindspore/ccsrc/fl/armour/cipher/cipher_init.h b/mindspore/ccsrc/fl/armour/cipher/cipher_init.h index 0f93be5665c..af2e829618c 100644 --- a/mindspore/ccsrc/fl/armour/cipher/cipher_init.h +++ b/mindspore/ccsrc/fl/armour/cipher/cipher_init.h @@ -40,27 +40,34 @@ class CipherInit { // Initialize the parameters of the secure aggregation. bool Init(const CipherPublicPara ¶m, size_t time_out_mutex, size_t cipher_exchange_keys_cnt, size_t cipher_get_keys_cnt, size_t cipher_share_secrets_cnt, size_t cipher_get_secrets_cnt, - size_t cipher_get_clientlist_cnt, size_t cipher_reconstruct_secrets_down_cnt, + size_t cipher_get_clientlist_cnt, size_t cipher_push_list_sign_cnt, size_t cipher_get_list_sign_cnt, size_t cipher_reconstruct_secrets_up_cnt); // Get public params. which is given to start fl job thread. CipherPublicPara *GetPublicParams() { return &publicparam_; } - size_t share_secrets_threshold; // the minimum number of clients to share secret fragments. - size_t reconstruct_secrets_threshold; // the minimum number of clients to reconstruct secret mask. - size_t exchange_key_threshold; // the minimum number of clients to send public keys. - size_t push_list_sign_threshold; // the minimum number of clients to push client list signature. - size_t secrets_minnums_; // the minimum number of secret fragment s to reconstruct secret mask. + size_t share_secrets_threshold; // the minimum number of clients to share + // secret fragments. + size_t reconstruct_secrets_threshold; // the minimum number of clients to + // reconstruct secret mask. + size_t exchange_key_threshold; // the minimum number of clients to send public + // keys. + size_t push_list_sign_threshold; // the minimum number of clients to push + // client list signature. + size_t secrets_minnums_; // the minimum number of secret fragment s to + // reconstruct secret mask. size_t featuremap_; // the size of data to deal. - - CipherPublicPara publicparam_; // the param containing encrypted public parameters. + CipherPublicPara publicparam_; // the param containing encrypted public parameters. CipherMetaStorage cipher_meta_storage_; private: - size_t client_list_threshold; // the minimum number of clients to get update model client list. + size_t client_list_threshold; // the minimum number of clients to get update + // model client list. size_t get_key_threshold; // the minimum number of clients to get public keys. - size_t get_list_sign_threshold; // the minimum number of clients to get client list signature. - size_t get_secrets_threshold; // the minimum number of clients to get secret fragments. + size_t get_list_sign_threshold; // the minimum number of clients to get client + // list signature. + size_t get_secrets_threshold; // the minimum number of clients to get secret + // fragments. size_t time_out_mutex_; // timeout mutex. // Check whether the parameters are valid. diff --git a/mindspore/ccsrc/fl/server/cert_verify.cc b/mindspore/ccsrc/fl/server/cert_verify.cc new file mode 100644 index 00000000000..ea7e23eec23 --- /dev/null +++ b/mindspore/ccsrc/fl/server/cert_verify.cc @@ -0,0 +1,633 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "fl/server/cert_verify.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mindspore { +namespace ps { +namespace server { +#ifndef _WIN32 +static int64_t replayAttackTimeDiff; +X509 *CertVerify::readCertFromFile(const std::string &certPath) { + BIO *bio = BIO_new_file(certPath.c_str(), "r"); + X509 *certObj = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr); + BIO_free_all(bio); + return certObj; +} + +X509 *CertVerify::readCertFromPerm(std::string cert) { + BIO *bio = BIO_new_mem_buf(reinterpret_cast(cert.data()), -1); + X509 *certObj = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr); + BIO_free_all(bio); + return certObj; +} + +X509_CRL *CertVerify::readCrlFromFile(const std::string &crlPath) { + BIO *bio = BIO_new_file(crlPath.c_str(), "r"); + X509_CRL *crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr); + BIO_free_all(bio); + return crl; +} + +bool checkFileExists(const std::string &file) { + std::ifstream f(file.c_str()); + if (!f.good()) { + return false; + } else { + f.close(); + return true; + } +} + +bool CertVerify::verifyCertTime(const X509 *cert) const { + ASN1_TIME *start = X509_getm_notBefore(cert); + ASN1_TIME *end = X509_getm_notAfter(cert); + + int day = 0; + int sec = 0; + int ret = ASN1_TIME_diff(&day, &sec, start, NULL); + if (ret != 1) { + return false; + } + + if (day < 0 || sec < 0) { + MS_LOG(ERROR) << "cert start time is later than now time."; + return false; + } + day = 0; + sec = 0; + ret = ASN1_TIME_diff(&day, &sec, NULL, end); + if (ret != 1) { + return false; + } + + if (day < 0 || sec < 0) { + MS_LOG(ERROR) << "cert end time is sooner than now time."; + return false; + } + + return true; +} + +bool CertVerify::verifyPublicKey(const X509 *keyAttestationCertObj, const X509 *equipCertObj, + const X509 *equipCACertObj, const X509 *rootFirstCA, const X509 *rootSecondCA) const { + bool result = true; + EVP_PKEY *equipPubKey = X509_get_pubkey(const_cast(equipCertObj)); + EVP_PKEY *equipCAPubKey = X509_get_pubkey(const_cast(equipCACertObj)); + EVP_PKEY *rootFirstPubKey = X509_get_pubkey(const_cast(rootFirstCA)); + EVP_PKEY *rootSecondPubKey = X509_get_pubkey(const_cast(rootSecondCA)); + do { + int ret = 0; + ret = X509_verify(const_cast(keyAttestationCertObj), equipPubKey); + if (ret != 1) { + MS_LOG(ERROR) << "keyAttestationCert verify is failed"; + result = false; + break; + } + ret = X509_verify(const_cast(equipCertObj), equipCAPubKey); + if (ret != 1) { + MS_LOG(ERROR) << "equip cert verify is failed"; + result = false; + break; + } + int ret_first = X509_verify(const_cast(equipCACertObj), rootFirstPubKey); + int ret_second = X509_verify(const_cast(equipCACertObj), rootSecondPubKey); + if (ret_first != 1 && ret_second != 1) { + MS_LOG(ERROR) << "equip ca cert verify is failed"; + result = false; + break; + } + } while (0); + + EVP_PKEY_free(equipPubKey); + EVP_PKEY_free(equipCAPubKey); + EVP_PKEY_free(rootFirstPubKey); + EVP_PKEY_free(rootSecondPubKey); + MS_LOG(INFO) << "verify Public Key success."; + return result; +} + +bool CertVerify::verifyCAChain(const std::string &keyAttestation, const std::string &equipCert, + const std::string &equipCACert, const std::string &rootFirstCAPath, + const std::string &rootSecondCAPath) { + X509 *rootFirstCA = CertVerify::readCertFromFile(rootFirstCAPath); + X509 *rootSecondCA = CertVerify::readCertFromFile(rootSecondCAPath); + X509 *keyAttestationCertObj = readCertFromPerm(keyAttestation); + X509 *equipCertObj = readCertFromPerm(equipCert); + X509 *equipCACertObj = readCertFromPerm(equipCACert); + bool result = true; + do { + if (rootFirstCA == nullptr || rootSecondCA == nullptr) { + MS_LOG(ERROR) << "rootFirstCA or rootSecondCA is nullptr"; + result = false; + break; + } + if (keyAttestationCertObj == nullptr || equipCertObj == nullptr || equipCACertObj == nullptr) { + result = false; + break; + } + + if (!verifyCertTime(keyAttestationCertObj) || !verifyCertTime(equipCertObj) || !verifyCertTime(equipCACertObj)) { + result = false; + break; + } + + if (!verifyCertCommonName(equipCACertObj, equipCertObj)) { + MS_LOG(ERROR) << "equip ca cert subject cn is not equal with equip cert issuer cn."; + result = false; + break; + } + + if (!verifyCertCommonName(rootFirstCA, equipCACertObj) && !verifyCertCommonName(rootSecondCA, equipCACertObj)) { + MS_LOG(ERROR) << "root CA cert subject cn is not equal with equip CA cert issuer cn."; + result = false; + break; + } + + if (!verifyExtendedAttributes(equipCACertObj)) { + MS_LOG(ERROR) << "verify equipCACert Extended Attributes failed."; + result = false; + break; + } + + if (!verifyCertKeyID(rootFirstCA, equipCACertObj) && !verifyCertKeyID(rootSecondCA, equipCACertObj)) { + MS_LOG(ERROR) << "root CA cert subject keyid is not equal with equip CA cert issuer keyid."; + result = false; + break; + } + + if (!verifyCertKeyID(equipCACertObj, equipCertObj)) { + MS_LOG(ERROR) << "equip CA cert subject keyid is not equal with equip cert issuer keyid."; + result = false; + break; + } + + if (!verifyPublicKey(keyAttestationCertObj, equipCertObj, equipCACertObj, rootFirstCA, rootSecondCA)) { + MS_LOG(ERROR) << "verify Public Key failed"; + result = false; + break; + } + } while (0); + X509_free(rootFirstCA); + X509_free(rootSecondCA); + X509_free(keyAttestationCertObj); + X509_free(equipCertObj); + X509_free(equipCACertObj); + MS_LOG(INFO) << "verifyCAChain success."; + return result; +} + +bool CertVerify::verifyCertKeyID(const X509 *caCert, const X509 *subCert) const { + bool result = true; + ASN1_OCTET_STRING *skid = nullptr; + AUTHORITY_KEYID *akeyid = nullptr; + do { + int crit = 0; + skid = reinterpret_cast(X509_get_ext_d2i(caCert, NID_subject_key_identifier, &crit, NULL)); + if (skid == nullptr) { + result = false; + break; + } + char subject_keyid[512] = {0}; + for (int i = 0; i < skid->length; i++) { + char keyid[8] = {0}; + int base = 512; + (void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)skid->data[i]); + int ret = strcat_s(subject_keyid, base, keyid); + if (ret == -1) { + result = false; + break; + } + } + + akeyid = reinterpret_cast(X509_get_ext_d2i(subCert, NID_authority_key_identifier, &crit, NULL)); + if (akeyid == nullptr) { + result = false; + break; + } + char issuer_keyid[512] = {0}; + if (akeyid->keyid == nullptr) { + MS_LOG(ERROR) << "keyid is nullprt."; + result = false; + break; + } + for (int i = 0; i < akeyid->keyid->length; i++) { + char keyid[8] = {0}; + int base = 512; + (void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)(akeyid->keyid->data[i])); + int ret = strcat_s(issuer_keyid, base, keyid); + if (ret == -1) { + result = false; + break; + } + } + + std::string subject_keyid_str = subject_keyid; + std::string issuer_keyid_str = issuer_keyid; + if (subject_keyid_str != issuer_keyid_str) { + result = false; + break; + } + } while (0); + ASN1_OCTET_STRING_free(skid); + AUTHORITY_KEYID_free(akeyid); + return result; +} + +bool CertVerify::verifyExtendedAttributes(const X509 *cert) const { + bool result = true; + BASIC_CONSTRAINTS *bcons = nullptr; + ASN1_BIT_STRING *lASN1UsageStr = nullptr; + do { + int cirt = 0; + bcons = reinterpret_cast(X509_get_ext_d2i(cert, NID_basic_constraints, &cirt, NULL)); + if (bcons == nullptr) { + result = false; + break; + } + if (!bcons->ca) { + MS_LOG(ERROR) << "Subject Type is End Entity."; + result = false; + break; + } + MS_LOG(INFO) << "Subject Type is CA."; + + lASN1UsageStr = reinterpret_cast(X509_get_ext_d2i(cert, NID_key_usage, NULL, NULL)); + if (lASN1UsageStr == nullptr) { + result = false; + break; + } + int16_t usage = lASN1UsageStr->data[0]; + if (lASN1UsageStr->length > 1) { + const unsigned int move = 8; + usage |= lASN1UsageStr->data[1] << move; + } + + if (!(usage & KU_KEY_CERT_SIGN)) { + MS_LOG(ERROR) << "Subject is not Certificate Signature."; + result = false; + break; + } + MS_LOG(INFO) << "Subject is Certificate Signature."; + } while (0); + BASIC_CONSTRAINTS_free(bcons); + ASN1_BIT_STRING_free(lASN1UsageStr); + return result; +} + +bool CertVerify::verifyCertCommonName(const X509 *caCert, const X509 *subCert) const { + if (caCert == nullptr || subCert == nullptr) { + return false; + } + + char caSubjectCN[256] = ""; + char subIssuerCN[256] = ""; + + X509_NAME *caSubjectX509CN = X509_get_subject_name(caCert); + X509_NAME *subIssuerX509CN = X509_get_issuer_name(subCert); + + int ret = X509_NAME_get_text_by_NID(caSubjectX509CN, NID_commonName, caSubjectCN, sizeof(caSubjectCN)); + if (ret < 0) { + return false; + } + ret = X509_NAME_get_text_by_NID(subIssuerX509CN, NID_commonName, subIssuerCN, sizeof(subIssuerCN)); + if (ret < 0) { + return false; + } + + std::string caSubjectCNStr = caSubjectCN; + std::string subIssuerCNStr = subIssuerCN; + + if (caSubjectCNStr != subIssuerCNStr) { + return false; + } + return true; +} + +bool CertVerify::verifyCRL(const std::string &equipCert, const std::string &equipCrlPath) { + if (!checkFileExists(equipCrlPath)) { + return true; + } + bool result = true; + X509_CRL *equipCrl = nullptr; + X509 *equipCertObj = nullptr; + EVP_PKEY *evp_pkey = nullptr; + do { + equipCrl = CertVerify::readCrlFromFile(equipCrlPath); + equipCertObj = readCertFromPerm(equipCert); + if (equipCertObj == nullptr) { + result = false; + break; + } + + if (equipCrl == nullptr) { + MS_LOG(INFO) << "equipCrl is nullptr. return true."; + result = true; + break; + } + evp_pkey = X509_get_pubkey(equipCertObj); + int ret = X509_CRL_verify(equipCrl, evp_pkey); + if (ret == 1) { + MS_LOG(ERROR) << "equip cert in equip crl, verify failed"; + result = false; + break; + } + } while (0); + + EVP_PKEY_free(evp_pkey); + X509_free(equipCertObj); + X509_CRL_free(equipCrl); + MS_LOG(INFO) << "verifyCRL success."; + return result; +} + +bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const unsigned char *signData, const std::string &flID, + const std::string &timeStamp) { + if (keyAttestation.empty() || signData == nullptr || flID.empty() || timeStamp.empty()) { + MS_LOG(ERROR) << "keyAttestation or signData or flID or timeStamp is empty."; + return false; + } + bool result = true; + X509 *keyAttestationCertObj = nullptr; + EVP_PKEY *pubKey = nullptr; + do { + keyAttestationCertObj = readCertFromPerm(keyAttestation); + + std::string srcData = flID + " " + timeStamp; + // SHA256_DIGEST_LENGTH is 32 + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + sha256Hash(srcData, srcDataHash, SHA256_DIGEST_LENGTH); + + pubKey = X509_get_pubkey(keyAttestationCertObj); + RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey); + if (pRSAPublicKey == nullptr) { + MS_LOG(ERROR) << "get rsa public key failed."; + result = false; + break; + } + + int pubKeyLen = RSA_size(pRSAPublicKey); + unsigned char buffer[256]; + int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING); + if (ret == -1) { + MS_LOG(ERROR) << "rsa public decrypt failed."; + result = false; + break; + } + + int saltLen = -2; + ret = RSA_verify_PKCS1_PSS(pRSAPublicKey, srcDataHash, EVP_sha256(), buffer, saltLen); + if (ret != 1) { + int64_t ulErr = SizeToLong(ERR_get_error()); + char szErrMsg[1024] = {0}; + MS_LOG(ERROR) << "verify error. error number: " << ulErr; + std::string str_res = ERR_error_string(ulErr, szErrMsg); + MS_LOG(ERROR) << szErrMsg; + if (str_res.empty()) { + result = false; + break; + } + result = false; + break; + } + } while (0); + EVP_PKEY_free(pubKey); + X509_free(keyAttestationCertObj); + CRYPTO_cleanup_all_ex_data(); + + MS_LOG(INFO) << "verifyRSAKey success."; + return result; +} + +void CertVerify::sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash, const int len) const { + if (len <= 0) { + return; + } + SHA256_CTX sha_ctx; + int ret = SHA256_Init(&sha_ctx); + if (ret != 1) { + return; + } + ret = SHA256_Update(&sha_ctx, src, src_len); + if (ret != 1) { + return; + } + ret = SHA256_Final(hash, &sha_ctx); + if (ret != 1) { + return; + } +} + +std::string CertVerify::toHexString(const unsigned char *data, const int len) { + if (data == nullptr) { + MS_LOG(ERROR) << "data hash is null."; + return ""; + } + + if (len <= 0) { + return ""; + } + std::stringstream ss; + int base = 2; + for (int i = 0; i < len; i++) { + ss << std::hex << std::setw(base) << std::setfill('0') << static_cast(data[i]); + } + return ss.str(); +} + +bool CertVerify::verifyEquipCertAndFlID(const std::string &flID, const std::string &equipCert) { + unsigned char hash[SHA256_DIGEST_LENGTH] = {""}; + sha256Hash(equipCert, hash, SHA256_DIGEST_LENGTH); + std::string equipCertSha256 = toHexString(hash, SHA256_DIGEST_LENGTH); + if (flID == equipCertSha256) { + MS_LOG(INFO) << "verifyEquipCertAndFlID success."; + return true; + } else { + MS_LOG(ERROR) << "verifyEquipCertAndFlID failed."; + return false; + } +} + +bool CertVerify::verifyTimeStamp(const std::string &flID, const std::string &timeStamp) const { + int64_t requestTime = std::stoll(timeStamp.c_str()); + const int64_t base = 1000; + struct timeval tv {}; + int ret = gettimeofday(&tv, nullptr); + if (ret != 0) { + return false; + } + int64_t now = tv.tv_sec * base + tv.tv_usec / base; + MS_LOG(INFO) << "flID: " << flID.c_str() << ",now time: " << now << ",requestTime: " << requestTime; + + int64_t diff = now - requestTime; + if (diff > replayAttackTimeDiff || diff < 0) { + return false; + } + MS_LOG(INFO) << "verifyTimeStamp success."; + return true; +} + +void CertVerify::sha256Hash(const std::string &src, uint8_t *hash, const int len) const { + if (len <= 0) { + return; + } + SHA256_CTX sha_ctx; + int ret = SHA256_Init(&sha_ctx); + if (ret != 1) { + return; + } + ret = SHA256_Update(&sha_ctx, src.c_str(), src.size()); + if (ret != 1) { + return; + } + ret = SHA256_Final(hash, &sha_ctx); + if (ret != 1) { + return; + } +} + +bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData, + int srcDataLen) { + if (keyAttestation.empty() || signData == nullptr || srcData == nullptr || srcDataLen <= 0) { + MS_LOG(ERROR) << "keyAttestation or signData or srcData is invalid."; + return false; + } + bool result = true; + X509 *keyAttestationCertObj = nullptr; + EVP_PKEY *pubKey = nullptr; + do { + keyAttestationCertObj = readCertFromPerm(keyAttestation); + pubKey = X509_get_pubkey(keyAttestationCertObj); + RSA *pRSAPublicKey = EVP_PKEY_get0_RSA(pubKey); + if (pRSAPublicKey == nullptr) { + MS_LOG(ERROR) << "get rsa public key failed."; + result = false; + break; + } + + int pubKeyLen = RSA_size(pRSAPublicKey); + unsigned char buffer[256]; + int ret = RSA_public_decrypt(pubKeyLen, signData, buffer, pRSAPublicKey, RSA_NO_PADDING); + if (ret == -1) { + MS_LOG(ERROR) << "rsa public decrypt failed."; + result = false; + break; + } + + int saltLen = -2; + ret = RSA_verify_PKCS1_PSS(pRSAPublicKey, srcData, EVP_sha256(), buffer, saltLen); + if (ret != 1) { + int64_t ulErr = SizeToLong(ERR_get_error()); + char szErrMsg[1024] = {0}; + MS_LOG(ERROR) << "verify error. error number: " << ulErr; + std::string str_res = ERR_error_string(ulErr, szErrMsg); + MS_LOG(ERROR) << szErrMsg; + if (str_res.empty()) { + result = false; + break; + } + result = false; + break; + } + } while (0); + EVP_PKEY_free(pubKey); + X509_free(keyAttestationCertObj); + CRYPTO_cleanup_all_ex_data(); + + MS_LOG(INFO) << "verifyRSAKey success."; + return result; +} + +bool CertVerify::initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath, + const std::string equipCrlPath, const uint64_t replay_attack_time_diff) { + if (rootFirstCaFilePath.empty() || rootSecondCaFilePath.empty()) { + MS_LOG(ERROR) << "the root or crl path is empty."; + return false; + } + + if (!checkFileExists(rootFirstCaFilePath)) { + MS_LOG(ERROR) << "The rootFirstCaFilePath is not exist."; + return false; + } + if (!checkFileExists(rootSecondCaFilePath)) { + MS_LOG(ERROR) << "The rootSecondCaFilePath is not exist."; + return false; + } + replayAttackTimeDiff = UlongToLong(replay_attack_time_diff); + return true; +} + +bool CertVerify::verifyCertAndSign(const std::string &flID, const std::string &timeStamp, const unsigned char *signData, + const std::string &keyAttestation, const std::string &equipCert, + const std::string &equipCACert, const std::string &rootFirstCAPath, + const std::string &rootSecondCAPath, const std::string &equipCrlPath) { + if (!verifyEquipCertAndFlID(flID, equipCert)) { + return false; + } + + if (!verifyCAChain(keyAttestation, equipCert, equipCACert, rootFirstCAPath, rootSecondCAPath)) { + return false; + } + + if (!verifyCRL(equipCert, equipCrlPath)) { + return false; + } + + if (!verifyRSAKey(keyAttestation, signData, flID, timeStamp)) { + return false; + } + + if (!verifyTimeStamp(flID, timeStamp)) { + return false; + } + return true; +} +#else +bool CertVerify::verifyTimeStamp(const std::string &flID, const std::string &timeStamp) const { + MS_LOG(WARNING) << "verifyTimeStamp in win32 platform."; + return false; +} +void CertVerify::sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash, const int len) const { + MS_LOG(WARNING) << "sha256Hash in win32 platform."; +} +bool CertVerify::verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData, + int srcDataLen) { + MS_LOG(WARNING) << "verifyRSAKey in win32 platform."; + return false; +} +bool CertVerify::initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath, + const std::string equipCrlPath, const uint64_t replay_attack_time_diff) { + MS_LOG(WARNING) << "initRootCertAndCRL in win32 platform."; + return false; +} +bool CertVerify::verifyCertAndSign(const std::string &flID, const std::string &timeStamp, const unsigned char *signData, + const std::string &keyAttestation, const std::string &equipCert, + const std::string &equipCACert, const std::string &rootFirstCAPath, + const std::string &rootSecondCAPath, const std::string &equipCrlPath) { + MS_LOG(WARNING) << "verifyCertAndSign in win32 platform."; + return false; +} +#endif +} // namespace server +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/cert_verify.h b/mindspore/ccsrc/fl/server/cert_verify.h new file mode 100644 index 00000000000..162018082a9 --- /dev/null +++ b/mindspore/ccsrc/fl/server/cert_verify.h @@ -0,0 +1,105 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FL_SERVER_CERT_VERIFY_H +#define MINDSPORE_CCSRC_FL_SERVER_CERT_VERIFY_H + +#include +#ifndef _WIN32 +#include +#include +#include +#include +#include +#include +#endif +#include +#include +#include +#include "utils/log_adapter.h" +#include "fl/server/common.h" + +namespace mindspore { +namespace ps { +namespace server { +class CertVerify { + public: + CertVerify() {} + ~CertVerify() = default; + + bool verifyCertAndSign(const std::string &flID, const std::string &timeStamp, const unsigned char *signData, + const std::string &keyAttestation, const std::string &equipCert, + const std::string &equipCACert, const std::string &rootFirstCAPath, + const std::string &rootSecondCAPath, const std::string &equipCrlPath); + + static bool initRootCertAndCRL(const std::string rootFirstCaFilePath, const std::string rootSecondCaFilePath, + const std::string equipCrlPath, uint64_t replay_attack_time_diff_); + + // verify valid of sign data + bool verifyRSAKey(const std::string &keyAttestation, const uint8_t *srcData, const uint8_t *signData, int srcDataLen); + + void sha256Hash(const uint8_t *src, const int src_len, uint8_t *hash, const int len) const; + + // verify valid of time stamp of request + bool verifyTimeStamp(const std::string &flID, const std::string &timeStamp) const; + +#ifndef _WIN32 + + private: + // read certificate from file path + static X509 *readCertFromFile(const std::string &certPath); + + // read Certificate Revocation List from file absolute path + static X509_CRL *readCrlFromFile(const std::string &crlPath); + + // read certificate from pem string + X509 *readCertFromPerm(std::string cert); + + // verify valid of certificate time + bool verifyCertTime(const X509 *cert) const; + + // verify valid of certificate chain + bool verifyCAChain(const std::string &keyAttestation, const std::string &equipCert, const std::string &equipCACert, + const std::string &rootFirstCAPath, const std::string &rootSecondCAPath); + + // verify valid of sign data + bool verifyRSAKey(const std::string &keyAttestation, const unsigned char *signData, const std::string &flID, + const std::string &timeStamp); + + // verify valid of equip certificate with CRL + bool verifyCRL(const std::string &equipCert, const std::string &equipCrlPath); + + // verify valid of flID with sha256(equip cert) + bool verifyEquipCertAndFlID(const std::string &flID, const std::string &equipCert); + + void sha256Hash(const std::string &src, uint8_t *hash, const int len) const; + + std::string toHexString(const unsigned char *data, const int len); + + bool verifyCertCommonName(const X509 *caCert, const X509 *subCert) const; + + bool verifyExtendedAttributes(const X509 *cert) const; + + bool verifyCertKeyID(const X509 *caCert, const X509 *subCert) const; + + bool verifyPublicKey(const X509 *keyAttestationCertObj, const X509 *equipCertObj, const X509 *equipCACertObj, + const X509 *rootFirstCA, const X509 *rootSecondCA) const; +#endif +}; +} // namespace server +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FL_SERVER_CERT_VERIFY_H diff --git a/mindspore/ccsrc/fl/server/collective_ops_impl.cc b/mindspore/ccsrc/fl/server/collective_ops_impl.cc index 870fa0222bb..206e6a4d85e 100644 --- a/mindspore/ccsrc/fl/server/collective_ops_impl.cc +++ b/mindspore/ccsrc/fl/server/collective_ops_impl.cc @@ -93,7 +93,6 @@ bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; return false; } - // Step 3: Reduce the data so we can overlap the time cost of send. for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) { recv_chunk[j] += tmp_recv_chunk[j]; @@ -164,9 +163,9 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec for (uint32_t i = 1; i < rank_size; i++) { std::shared_ptr> recv_str; MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i; - auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str); - if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) { - MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; + auto recv_req_id1 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, i, &recv_str); + if (!server_node_->CollectiveWait(recv_req_id1, kCollectiveCommTimeout)) { + MS_LOG(ERROR) << "CollectiveWait " << recv_req_id1 << " failed."; return false; } ret = memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size()); @@ -180,9 +179,9 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec } } else { MS_LOG(DEBUG) << "Reduce send data to rank 0 process."; - auto send_req_id = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T)); - if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) { - MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; + auto send_req_id1 = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T)); + if (!server_node_->Wait(send_req_id1, kCollectiveCommTimeout)) { + MS_LOG(ERROR) << "CollectiveWait " << send_req_id1 << " failed."; return false; } } @@ -193,19 +192,19 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec if (rank_id_ == 0) { for (uint32_t i = 1; i < rank_size; i++) { MS_LOG(DEBUG) << "Broadcast data to process " << i; - auto send_req_id = + auto send_req_id2 = server_node_->CollectiveSendAsync(ps::core::NodeRole::SERVER, i, output_buff, count * sizeof(T)); - if (!server_node_->Wait(send_req_id, kCollectiveCommTimeout)) { - MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; + if (!server_node_->Wait(send_req_id2, kCollectiveCommTimeout)) { + MS_LOG(ERROR) << "CollectiveWait " << send_req_id2 << " failed."; return false; } } } else { MS_LOG(DEBUG) << "Broadcast receive from rank 0."; std::shared_ptr> recv_str; - auto recv_req_id = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str); - if (!server_node_->CollectiveWait(recv_req_id, kCollectiveCommTimeout)) { - MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; + auto recv_req_id2 = server_node_->CollectiveReceiveAsync(ps::core::NodeRole::SERVER, 0, &recv_str); + if (!server_node_->CollectiveWait(recv_req_id2, kCollectiveCommTimeout)) { + MS_LOG(ERROR) << "CollectiveWait " << recv_req_id2 << " failed."; return false; } ret = memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size()); @@ -219,7 +218,7 @@ bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *rec } template -bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size_t send_count) { +bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *const recvbuff, size_t send_count) { MS_ERROR_IF_NULL_W_RET_VAL(node_, false); MS_ERROR_IF_NULL_W_RET_VAL(sendbuff, false); MS_ERROR_IF_NULL_W_RET_VAL(recvbuff, false); @@ -230,8 +229,8 @@ bool CollectiveOpsImpl::RingAllGather(const void *sendbuff, void *recvbuff, size // Store offsets to get every data chunk's address. std::vector chunk_offset; for (size_t i = 0; i < rank_size_; i++) { - size_t ofs = - std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast(0), std::plus()); + size_t ofs = std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + SizeToLong(i), static_cast(0), + std::plus()); chunk_offset.push_back(ofs); } @@ -295,7 +294,7 @@ bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t c MS_LOG(ERROR) << "The group is empty."; return false; } - uint32_t group_rank_size = group_info.group_ranks.size(); + uint32_t group_rank_size = SizeToUint(group_info.group_ranks.size()); uint32_t global_root_rank = group_to_global_ranks[root]; // Broadcast data to processes which are not the root. @@ -349,7 +348,7 @@ bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t c } template -bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t send_count, +bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *const recvbuff, size_t send_count, const std::shared_ptr &node) { std::unique_lock lock(mtx_); MS_ERROR_IF_NULL_W_RET_VAL(node, false); @@ -362,10 +361,10 @@ bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t s rank_id_ = node_->rank_id(); switch (node_role_) { case ps::core::WORKER: - rank_size_ = node_->worker_num(); + rank_size_ = IntToUint(node_->worker_num()); break; case ps::core::SERVER: - rank_size_ = node_->server_num(); + rank_size_ = IntToUint(node_->server_num()); break; default: MS_LOG(ERROR) << "The node role " << node_role_ << " for collective communication is invalid."; @@ -380,7 +379,7 @@ bool CollectiveOpsImpl::AllGather(const void *sendbuff, void *recvbuff, size_t s } template -bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *recvbuff, size_t count, uint32_t root, +bool CollectiveOpsImpl::Broadcast(const void *sendbuff, void *const recvbuff, size_t count, uint32_t root, const std::shared_ptr &node, const CommunicationGroupInfo &group_info) { std::unique_lock lock(mtx_); diff --git a/mindspore/ccsrc/fl/server/common.h b/mindspore/ccsrc/fl/server/common.h index 8541077c254..dbc69e49564 100644 --- a/mindspore/ccsrc/fl/server/common.h +++ b/mindspore/ccsrc/fl/server/common.h @@ -45,17 +45,17 @@ enum CommType { HTTP = 0, TCP }; enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum }; struct RoundConfig { - // The name of round. Please refer to round kernel *.cc files. + // The name of the round. Please refer to round kernel *.cc files. std::string name; // Whether this round has the time window limit. bool check_timeout = false; // The length of the time window. Only used when check_timeout is set to true. size_t time_window = 3000; - // Whether this round has to check the request count has reached the threshold. + // Whether this round has to check the request count has reach the threshold. bool check_count = false; - // This round's request threshold count. Only used when check_count is set to true. + // This round's request threshold count. Only used when threshold_count is set to true. size_t threshold_count = 0; - // Whether this round uses the server as threshold count. This is vital for some rounds in elastic scaling scenario. + // Whether this round uses the server number as threshold. This is vital for some rounds in elastic scaling scenario. bool server_num_as_threshold = false; }; @@ -67,6 +67,8 @@ struct CipherConfig { size_t share_secrets_threshold = 0; size_t get_secrets_threshold = 0; size_t client_list_threshold = 0; + size_t push_list_sign_threshold = 0; + size_t get_list_sign_threshold = 0; size_t reconstruct_secrets_threshold = 0; }; @@ -130,7 +132,6 @@ constexpr auto kAdamEps = "eps"; constexpr auto kFtrlLinear = "linear"; constexpr auto kDataSize = "data_size"; constexpr auto kNewDataSize = "new_data_size"; -constexpr auto kStat = "stat"; // OptimParamNameToIndex represents every inputs/workspace/outputs parameter's offset when an optimizer kernel is // launched. @@ -175,14 +176,11 @@ const OptimParamNameToIndex kAdamWeightDecayNameToIdx = {{"inputs", {"weight_decay", 7}, {"grad", 8}}}, {"outputs", {}}}; -const OptimParamNameToIndex kSGDNameToIdx = { - {"inputs", {{kWeight, 0}, {kGradient, 1}, {kLearningRate, 2}, {kAccumulation, 3}, {kMomentum, 4}, {kStat, 5}}}, - {"outputs", {}}}; - -const std::map kNameToIdxMap = { - {kApplyMomentumOpName, kMomentumNameToIdx}, {kFusedSparseAdamName, kSparseAdamNameToIdx}, - {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx}, {kApplyAdamOpName, kAdamNameToIdx}, - {"AdamWeightDecay", kAdamWeightDecayNameToIdx}, {kSGDName, kSGDNameToIdx}}; +const std::map kNameToIdxMap = {{kApplyMomentumOpName, kMomentumNameToIdx}, + {kFusedSparseAdamName, kSparseAdamNameToIdx}, + {kSparseApplyFtrlOpName, kSparseFtrlNameToIdx}, + {kApplyAdamOpName, kAdamNameToIdx}, + {"AdamWeightDecay", kAdamWeightDecayNameToIdx}}; constexpr uint32_t kLeaderServerRank = 0; constexpr size_t kWorkerMgrThreadPoolSize = 32; @@ -215,9 +213,12 @@ constexpr auto kCtxGetSecretsClientList = "get_secrets_client_list"; constexpr auto kCtxReconstructClientList = "reconstruct_client_list"; constexpr auto kCtxExChangeKeysClientList = "exchange_keys_client_list"; constexpr auto kCtxGetUpdateModelClientList = "get_update_model_client_list"; +constexpr auto kCtxClientListSigns = "client_list_signs"; +constexpr auto kCtxClientKeyAttestation = "client_key_attestation"; constexpr auto kCtxGetKeysClientList = "get_keys_client_list"; constexpr auto kCtxFedAvgTotalDataSize = "fed_avg_total_data_size"; constexpr auto kCtxCipherPrimer = "cipher_primer"; +constexpr auto kCurrentIteration = "current_iteration"; // This macro the current timestamp in milliseconds. #define CURRENT_TIME_MILLI \ @@ -252,6 +253,14 @@ inline AddressPtr GenerateParameterNodeAddrPtr(const CNodePtr &kernel_node, size return addr; } +template +inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) { + if (!json.contains(key)) { + MS_LOG(EXCEPTION) << "The key " << key << "does not exist in json " << json.dump(); + } + return json[key].get(); +} + // Definitions for Federated Learning. constexpr auto kNetworkError = "Cluster networking failed."; diff --git a/mindspore/ccsrc/fl/server/consistent_hash_ring.cc b/mindspore/ccsrc/fl/server/consistent_hash_ring.cc index 52aee7ab496..962177ba83d 100644 --- a/mindspore/ccsrc/fl/server/consistent_hash_ring.cc +++ b/mindspore/ccsrc/fl/server/consistent_hash_ring.cc @@ -26,7 +26,7 @@ bool ConsistentHashRing::Insert(uint32_t rank) { MS_LOG(DEBUG) << "Insert virtual node " << physical_node_hash_key << " for node " << rank << ", hash value is " << hash_value; if (ring_.count(hash_value) != 0) { - MS_LOG(INFO) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring."; + MS_LOG(WARNING) << "Virtual node " << physical_node_hash_key << " is already mapped to the ring."; continue; } ring_[hash_value] = rank; @@ -37,7 +37,7 @@ bool ConsistentHashRing::Insert(uint32_t rank) { bool ConsistentHashRing::Erase(uint32_t rank) { for (auto iterator = ring_.begin(); iterator != ring_.end();) { if (iterator->second == rank) { - iterator = ring_.erase(iterator); + (void)ring_.erase(iterator++); } else { ++iterator; } diff --git a/mindspore/ccsrc/fl/server/distributed_count_service.cc b/mindspore/ccsrc/fl/server/distributed_count_service.cc index f143d2427c5..69c18f47e07 100644 --- a/mindspore/ccsrc/fl/server/distributed_count_service.cc +++ b/mindspore/ccsrc/fl/server/distributed_count_service.cc @@ -329,6 +329,7 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st // Broadcast to all follower servers. for (uint32_t i = 1; i < server_num_; i++) { + MS_LOG(INFO) << "Start sending first count event message to server " << i; if (!communicator_->SendPbRequest(first_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { MS_LOG(ERROR) << "Activating first count event to server " << i << " failed."; if (reason != nullptr) { @@ -343,7 +344,9 @@ bool DistributedCountService::TriggerFirstCountEvent(const std::string &name, st return false; } // Leader server directly calls the callback. + MS_LOG(INFO) << "Leader server call first count handler for " << name << "..."; counter_handlers_[name].first_count_handler(nullptr); + MS_LOG(INFO) << "First count handler for " << name << " is successfully called."; return true; } @@ -355,6 +358,7 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std // Broadcast to all follower servers. for (uint32_t i = 1; i < server_num_; i++) { + MS_LOG(INFO) << "Start sending last count event message to server " << i; if (!communicator_->SendPbRequest(last_count_event, i, ps::core::TcpUserCommand::kCounterEvent)) { MS_LOG(ERROR) << "Activating last count event to server " << i << " failed."; if (reason != nullptr) { @@ -369,7 +373,9 @@ bool DistributedCountService::TriggerLastCountEvent(const std::string &name, std return false; } // Leader server directly calls the callback. + MS_LOG(INFO) << "Leader server call last count handler for " << name << "..."; counter_handlers_[name].last_count_handler(nullptr); + MS_LOG(INFO) << "Last count handler for " << name << " is successfully called."; return true; } } // namespace server diff --git a/mindspore/ccsrc/fl/server/distributed_count_service.h b/mindspore/ccsrc/fl/server/distributed_count_service.h index 60575d5aca0..d98f2e9f195 100644 --- a/mindspore/ccsrc/fl/server/distributed_count_service.h +++ b/mindspore/ccsrc/fl/server/distributed_count_service.h @@ -21,7 +21,6 @@ #include #include #include -#include "utils/hash_map.h" #include "proto/ps.pb.h" #include "fl/server/common.h" #include "ps/core/server_node.h" @@ -118,14 +117,14 @@ class DistributedCountService { // Key: name, e.g, startFLJob, updateModel, push. // Value: a set of id without repeatation because each work may report multiple times. - mindspore::HashMap> global_current_count_; + std::unordered_map> global_current_count_; // Key: name, e.g, StartFLJobCount. // Value: global threshold count in the server cluster dimension for this name. - mindspore::HashMap global_threshold_count_; + std::unordered_map global_threshold_count_; // First/last count event callbacks of the name. - mindspore::HashMap counter_handlers_; + std::unordered_map counter_handlers_; // Because the count is increased/queried conccurently, we must ensure the operations are threadsafe. std::unordered_map mutex_; diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc index b0b80c1331a..4ea1238f8da 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.cc +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.cc @@ -125,10 +125,6 @@ PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; return {}; } - if (metadata_.count(name) == 0) { - MS_LOG(ERROR) << "The metadata of " << name << " is not registered."; - return {}; - } uint32_t stored_rank = router_->Find(name); MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank; if (local_rank_ == stored_rank) { diff --git a/mindspore/ccsrc/fl/server/distributed_metadata_store.h b/mindspore/ccsrc/fl/server/distributed_metadata_store.h index d14db23d736..743ecf33913 100644 --- a/mindspore/ccsrc/fl/server/distributed_metadata_store.h +++ b/mindspore/ccsrc/fl/server/distributed_metadata_store.h @@ -20,7 +20,6 @@ #include #include #include -#include "utils/hash_map.h" #include "proto/ps.pb.h" #include "fl/server/common.h" #include "ps/core/server_node.h" @@ -107,7 +106,7 @@ class DistributedMetadataStore { // We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use. // Key: data name. // Value: ProtoBuffer Struct. - mindspore::HashMap metadata_; + std::unordered_map metadata_; // Because the metadata is read/written conccurently, we must ensure the operations are threadsafe. std::unordered_map mutex_; diff --git a/mindspore/ccsrc/fl/server/executor.cc b/mindspore/ccsrc/fl/server/executor.cc index f121f5aa3b5..42057616e61 100644 --- a/mindspore/ccsrc/fl/server/executor.cc +++ b/mindspore/ccsrc/fl/server/executor.cc @@ -65,47 +65,6 @@ bool Executor::ReInitForUpdatingHyperParams(size_t aggr_threshold) { bool Executor::initialized() const { return initialized_; } -bool Executor::HandlePush(const std::string ¶m_name, const UploadData &upload_data) { - MS_LOG(DEBUG) << "Do Push for parameter " << param_name; - if (param_aggrs_.count(param_name) == 0) { - MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server."; - return false; - } - - std::mutex &mtx = parameter_mutex_[param_name]; - std::unique_lock lock(mtx); - auto ¶m_aggr = param_aggrs_[param_name]; - MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false); - // Push operation needs to wait until the pulling process is done. - while (!param_aggr->IsPullingDone()) { - lock.unlock(); - std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime)); - lock.lock(); - } - - // 1.Update data with the uploaded data of the worker. - if (!param_aggr->UpdateData(upload_data)) { - MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed."; - return false; - } - // 2.Launch aggregation for this trainable parameter. - if (!param_aggr->LaunchAggregators()) { - MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed."; - return false; - } - if (param_aggr->IsAggregationDone()) { - // 3.After the aggregation is done, optimize the trainable parameter. - if (!param_aggr->LaunchOptimizers()) { - MS_LOG(ERROR) << "Optimizing for parameter " << param_name << " failed."; - return false; - } - // 4.Reset pulling and aggregation status after optimizing is done. - param_aggr->ResetPullingStatus(); - param_aggr->ResetAggregationStatus(); - } - return true; -} - bool Executor::HandleModelUpdate(const std::string ¶m_name, const UploadData &upload_data) { MS_LOG(DEBUG) << "Do UpdateModel for parameter " << param_name; if (param_aggrs_.count(param_name) == 0) { @@ -131,32 +90,6 @@ bool Executor::HandleModelUpdate(const std::string ¶m_name, const UploadData return true; } -bool Executor::HandleModelUpdateAsync(const std::map &feature_map) { - std::unique_lock model_lock(model_mutex_); - for (const auto &trainable_param : feature_map) { - const std::string ¶m_name = trainable_param.first; - if (param_aggrs_.count(param_name) == 0) { - MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server."; - continue; - } - - std::mutex &mtx = parameter_mutex_[param_name]; - std::unique_lock lock(mtx); - auto ¶m_aggr = param_aggrs_[param_name]; - MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, false); - const UploadData &upload_data = trainable_param.second; - if (!param_aggr->UpdateData(upload_data)) { - MS_LOG(ERROR) << "Updating data for parameter " << param_name << " failed."; - return false; - } - if (!param_aggr->LaunchAggregators()) { - MS_LOG(ERROR) << "Launching aggregators for parameter " << param_name << " failed."; - return false; - } - } - return true; -} - bool Executor::HandlePushWeight(const std::map &feature_map) { for (const auto &trainable_param : feature_map) { const std::string ¶m_name = trainable_param.first; @@ -183,31 +116,6 @@ bool Executor::HandlePushWeight(const std::map &feature_ma return true; } -AddressPtr Executor::HandlePull(const std::string ¶m_name) { - MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name; - if (param_aggrs_.count(param_name) == 0) { - MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server."; - return nullptr; - } - - std::mutex &mtx = parameter_mutex_[param_name]; - std::unique_lock lock(mtx); - auto ¶m_aggr = param_aggrs_[param_name]; - MS_ERROR_IF_NULL_W_RET_VAL(param_aggr, nullptr); - // Pulling must wait until the optimizing process is done. - while (!param_aggr->IsOptimizingDone()) { - lock.unlock(); - std::this_thread::sleep_for(std::chrono::milliseconds(kThreadSleepTime)); - lock.lock(); - } - AddressPtr addr = param_aggr->Pull(); - // If this Pull is the last one, reset pulling and optimizing status. - if (param_aggr->IsPullingDone()) { - param_aggr->ResetOptimizingStatus(); - } - return addr; -} - std::map Executor::HandlePullWeight(const std::vector ¶m_names) { std::map weights; for (const auto ¶m_name : param_names) { @@ -298,7 +206,7 @@ bool Executor::unmasked() const { if (encrypt_type == ps::kPWEncryptType) { return unmasked_.load(); } else { - // If the algorithm of pairwise encrypt is not enabled, consider_ unmasked flag as true. + // If the algorithm of mind armour is not enabled, consider unmasked_ flag as true. return true; } } @@ -340,7 +248,7 @@ bool Executor::InitParamAggregator(const FuncGraphPtr &func_graph) { param_aggrs_[param_name] = param_aggr; parameter_mutex_[param_name]; if (!param_aggr->Init(cnode, aggregation_count_)) { - MS_LOG(EXCEPTION) << "Initializing parameter aggregator for " << param_name << " failed."; + MS_LOG(EXCEPTION) << "Initializing parameter aggregator for param_name " << param_name << " failed."; return false; } MS_LOG(DEBUG) << "Initializing parameter aggregator for param_name " << param_name << " success."; diff --git a/mindspore/ccsrc/fl/server/executor.h b/mindspore/ccsrc/fl/server/executor.h index 3bc90288d5f..428a4482e11 100644 --- a/mindspore/ccsrc/fl/server/executor.h +++ b/mindspore/ccsrc/fl/server/executor.h @@ -24,11 +24,11 @@ #include #include #include -#include "fl/server/common.h" -#include "fl/server/parameter_aggregator.h" #ifdef ENABLE_ARMOUR #include "fl/armour/cipher/cipher_unmask.h" #endif +#include "fl/server/common.h" +#include "fl/server/parameter_aggregator.h" namespace mindspore { namespace fl { @@ -54,28 +54,13 @@ class Executor { // After hyper-parameters are updated, some parameter aggregators should be reinitialized. bool ReInitForUpdatingHyperParams(size_t aggr_threshold); - // Called in parameter server training mode to do Push operation. - // For the same trainable parameter, HandlePush method must be called aggregation_count_ times before it's considered - // as completed. - bool HandlePush(const std::string ¶m_name, const UploadData &upload_data); - - // Called in parameter server training mode to do Pull operation. - // Returns the value of parameter param_name. - // HandlePull method must be called the same times as HandlePush is called before it's considered as - // completed. - AddressPtr HandlePull(const std::string ¶m_name); - // Called in federated learning training mode. Update value for parameter param_name. bool HandleModelUpdate(const std::string ¶m_name, const UploadData &upload_data); - // Called in asynchronous federated learning training mode. Update current model with the new feature map - // asynchronously. - bool HandleModelUpdateAsync(const std::map &feature_map); - - // Overwrite the weights in server using pushed feature map. + // Forcibly overwrite specific weights in overwriteWeights message. bool HandlePushWeight(const std::map &feature_map); - // Returns multiple trainable parameters passed by weight_names. + // Returns value for multiple trainable parameters passed by weight_names. std::map HandlePullWeight(const std::vector ¶m_names); // Reset the aggregation status for all aggregation kernels in the server. @@ -135,7 +120,7 @@ class Executor { armour::CipherUnmask cipher_unmask_; #endif - // The flag represents the unmasking status. + // The flag refers to the unmasking status std::atomic unmasked_; }; } // namespace server diff --git a/mindspore/ccsrc/fl/server/iteration.cc b/mindspore/ccsrc/fl/server/iteration.cc index a300ad5e451..3011225291b 100644 --- a/mindspore/ccsrc/fl/server/iteration.cc +++ b/mindspore/ccsrc/fl/server/iteration.cc @@ -16,8 +16,8 @@ #include "fl/server/iteration.h" #include -#include #include +#include #include #include "fl/server/model_store.h" #include "fl/server/server.h" @@ -38,6 +38,7 @@ Iteration::~Iteration() { void Iteration::RegisterMessageCallback(const std::shared_ptr &communicator) { MS_EXCEPTION_IF_NULL(communicator); communicator_ = communicator; + MS_EXCEPTION_IF_NULL(communicator_); communicator_->RegisterMsgCallBack("syncIteration", std::bind(&Iteration::HandleSyncIterationRequest, this, std::placeholders::_1)); communicator_->RegisterMsgCallBack( @@ -54,10 +55,10 @@ void Iteration::RegisterMessageCallback(const std::shared_ptr &server_node) { MS_EXCEPTION_IF_NULL(server_node); server_node_ = server_node; - server_node->RegisterCustomEventCallback(static_cast(ps::CustomEvent::kIterationRunning), - std::bind(&Iteration::HandleIterationRunningEvent, this)); - server_node->RegisterCustomEventCallback(static_cast(ps::CustomEvent::kIterationCompleted), - std::bind(&Iteration::HandleIterationCompletedEvent, this)); + server_node->RegisterCustomEventCallback(static_cast(ps::UserDefineEvent::kIterationRunning), + std::bind(&Iteration::ProcessIterationRunningEvent, this)); + server_node->RegisterCustomEventCallback(static_cast(ps::UserDefineEvent::kIterationCompleted), + std::bind(&Iteration::ProcessIterationEndEvent, this)); } void Iteration::AddRound(const std::shared_ptr &round) { @@ -97,6 +98,7 @@ void Iteration::InitRounds(const std::vectorrank_id() == kLeaderServerRank) { // This event helps worker/server to be consistent in iteration state. - server_node_->BroadcastEvent(static_cast(ps::CustomEvent::kIterationRunning)); + server_node_->BroadcastEvent(static_cast(ps::UserDefineEvent::kIterationRunning)); + if (server_recovery_ != nullptr) { + // Save data to the persistent storage in case the recovery happens at the beginning. + if (!server_recovery_->Save(iteration_num_)) { + MS_LOG(WARNING) << "Save recovery data failed."; + } + } } std::unique_lock lock(iteration_state_mtx_); @@ -155,12 +164,12 @@ void Iteration::SetIterationRunning() { start_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()); } -void Iteration::SetIterationCompleted() { - MS_LOG(INFO) << "Iteration " << iteration_num_ << " completes."; +void Iteration::SetIterationEnd() { + MS_LOG(INFO) << "Iteration " << iteration_num_ << " ends."; MS_ERROR_IF_NULL_WO_RET_VAL(server_node_); if (server_node_->rank_id() == kLeaderServerRank) { // This event helps worker/server to be consistent in iteration state. - server_node_->BroadcastEvent(static_cast(ps::CustomEvent::kIterationCompleted)); + server_node_->BroadcastEvent(static_cast(ps::UserDefineEvent::kIterationCompleted)); } std::unique_lock lock(iteration_state_mtx_); @@ -274,7 +283,7 @@ bool Iteration::DisableServerInstance(std::string *result) { instance_state_ = InstanceState::kDisable; if (!ForciblyMoveToNextIteration()) { *result = "Disabling instance failed. Can't drop current iteration and move to the next."; - MS_LOG(ERROR) << *result; + MS_LOG(ERROR) << result; return false; } *result = "Disabling FL-Server succeeded."; @@ -295,6 +304,11 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string return false; } + if (iteration_num_ == 1) { + MS_LOG(INFO) << "This is just the first iteration."; + return true; + } + // Start new server instance. is_instance_being_updated_ = true; @@ -312,7 +326,7 @@ bool Iteration::NewInstance(const nlohmann::json &new_instance_json, std::string ModelStore::GetInstance().Reset(); if (metrics_ != nullptr) { if (!metrics_->Clear()) { - MS_LOG(WARNING) << "Clear metrics fil failed."; + MS_LOG(WARNING) << "Clear metrics file failed."; } } @@ -340,6 +354,16 @@ void Iteration::WaitAllRoundsFinish() const { } } +void Iteration::set_recovery_handler(const std::shared_ptr &server_recovery) { + MS_EXCEPTION_IF_NULL(server_recovery); + server_recovery_ = server_recovery; +} + +bool Iteration::SyncAfterRecovery(uint64_t) { + NotifyNext(false, "Move to next iteration after recovery."); + return true; +} + bool Iteration::SyncIteration(uint32_t rank) { MS_ERROR_IF_NULL_W_RET_VAL(communicator_, false); SyncIterationRequest sync_iter_req; @@ -348,7 +372,7 @@ bool Iteration::SyncIteration(uint32_t rank) { std::shared_ptr> sync_iter_rsp_msg = nullptr; if (!communicator_->SendPbRequest(sync_iter_req, kLeaderServerRank, ps::core::TcpUserCommand::kSyncIteration, &sync_iter_rsp_msg)) { - MS_LOG(ERROR) << "Sending synchronizing iteration message to leader server failed."; + MS_LOG(ERROR) << "Sending sync iter message to leader server failed."; return false; } @@ -356,8 +380,7 @@ bool Iteration::SyncIteration(uint32_t rank) { SyncIterationResponse sync_iter_rsp; (void)sync_iter_rsp.ParseFromArray(sync_iter_rsp_msg->data(), SizeToInt(sync_iter_rsp_msg->size())); iteration_num_ = sync_iter_rsp.iteration(); - MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is " - << sync_iter_rsp.iteration(); + MS_LOG(INFO) << "After synchronizing, server " << rank << " current iteration number is " << iteration_num_; return true; } @@ -496,7 +519,9 @@ void Iteration::HandlePrepareForNextIterRequest(const std::shared_ptrrank_id() == kLeaderServerRank) { + // Save current iteration number for recovery. + MS_ERROR_IF_NULL_WO_RET_VAL(server_recovery_); + if (!server_recovery_->Save(iteration_num_)) { + MS_LOG(WARNING) << "Can't save current iteration number into persistent storage."; + } + } + Server::GetInstance().CancelSafeMode(); iteration_state_cv_.notify_all(); MS_LOG(INFO) << "Move to next iteration:" << iteration_num_ << "\n"; diff --git a/mindspore/ccsrc/fl/server/iteration.h b/mindspore/ccsrc/fl/server/iteration.h index a5b159f8bf2..c1961669144 100644 --- a/mindspore/ccsrc/fl/server/iteration.h +++ b/mindspore/ccsrc/fl/server/iteration.h @@ -25,6 +25,7 @@ #include "fl/server/round.h" #include "fl/server/local_meta_store.h" #include "fl/server/iteration_metrics.h" +#include "fl/server/server_recovery.h" namespace mindspore { namespace fl { @@ -52,7 +53,7 @@ class Iteration { // Register callbacks for other servers to synchronize iteration information from leader server. void RegisterMessageCallback(const std::shared_ptr &communicator); - // Register event callbacks for iteration state synchronization. + // Register event callback for iteration state synchronization. void RegisterEventCallback(const std::shared_ptr &server_node); // Add a round for the iteration. This method will be called multiple times for each round. @@ -70,14 +71,14 @@ class Iteration { // This method will control servers to proceed to next iteration. // There's communication between leader and follower servers in this method. - // The server moves to next iteration only after the last round finishes or the time expires. + // The server moves to the next iteration only after the last round finishes or the timer expires. void MoveToNextIteration(bool is_last_iter_valid, const std::string &reason); - // Set current iteration state to running and trigger events about kIterationRunning. + // Set current iteration state to running and trigger the event. void SetIterationRunning(); - // Set current iteration state to completed and trigger the event about kIterationCompleted. - void SetIterationCompleted(); + // Set current iteration state to end and trigger the event. + void SetIterationEnd(); // The barrier function for elastic scaling. The scaling out/in operation should be done only after this iteration is // completed. @@ -118,6 +119,12 @@ class Iteration { // Need to wait all the rounds to finish before proceed to next iteration. void WaitAllRoundsFinish() const; + // Set server's recovery handler. + void set_recovery_handler(const std::shared_ptr &server_recovery); + + // Synchronize server iteration after another server's recovery is completed. + bool SyncAfterRecovery(uint64_t iteration_num); + // The round kernels whose Launch method has not returned yet. std::atomic_uint32_t running_round_num_; @@ -150,10 +157,10 @@ class Iteration { Iteration &operator=(const Iteration &) = delete; // The server does not need to handle the iteration events for now. - void HandleIterationRunningEvent() {} - void HandleIterationCompletedEvent() {} + void ProcessIterationRunningEvent() {} + void ProcessIterationEndEvent() {} - // Synchronize iteration form the leader server(Rank 0). + // Synchronize iteration from the leader server(Rank 0). bool SyncIteration(uint32_t rank); void HandleSyncIterationRequest(const std::shared_ptr &message); @@ -165,13 +172,13 @@ class Iteration { bool NotifyLeaderMoveToNextIteration(bool is_last_iter_valid, const std::string &reason); void HandleNotifyLeaderMoveToNextIterRequest(const std::shared_ptr &message); - // Step 2: leader server broadcast to all follower servers to prepare for next iteration and switch to safemode. + // Step 2: leader server broadcasts to all follower servers to prepare for next iteration and switch to safemode.. bool BroadcastPrepareForNextIterRequest(bool is_last_iter_valid, const std::string &reason); void HandlePrepareForNextIterRequest(const std::shared_ptr &message); // The server prepare for the next iteration. This method will switch the server to safemode. void PrepareForNextIter(); - // Step 3: leader server broadcast to all follower servers to move to next iteration. + // Step 3: leader server broadcasts to all follower servers to move to next iteration. bool BroadcastMoveToNextIterRequest(bool is_last_iter_valid, const std::string &reason); void HandleMoveToNextIterRequest(const std::shared_ptr &message); // Move to next iteration. Store last iterations model and reset all the rounds. @@ -201,6 +208,9 @@ class Iteration { // All the rounds in the server. std::vector> rounds_; + // The recovery object for server. + std::shared_ptr server_recovery_; + // The iteration is either running or completed at any time. std::mutex iteration_state_mtx_; std::condition_variable iteration_state_cv_; diff --git a/mindspore/ccsrc/fl/server/iteration_metrics.cc b/mindspore/ccsrc/fl/server/iteration_metrics.cc index 5c735afc732..46f8fdeaf96 100644 --- a/mindspore/ccsrc/fl/server/iteration_metrics.cc +++ b/mindspore/ccsrc/fl/server/iteration_metrics.cc @@ -17,6 +17,7 @@ #include "fl/server/iteration_metrics.h" #include #include +#include "utils/file_utils.h" #include "debug/common.h" #include "ps/constants.h" @@ -68,7 +69,7 @@ bool IterationMetrics::Initialize() { } bool IterationMetrics::Summarize() { - metrics_file_.open(metrics_file_path_, std::ios::ate | std::ios::out); + metrics_file_.open(metrics_file_path_, std::ios::out | std::ios::app); if (!metrics_file_.is_open()) { MS_LOG(ERROR) << "The metrics file is not opened."; return false; diff --git a/mindspore/ccsrc/fl/server/iteration_metrics.h b/mindspore/ccsrc/fl/server/iteration_metrics.h index de011ed4ae5..96c9d0d938b 100644 --- a/mindspore/ccsrc/fl/server/iteration_metrics.h +++ b/mindspore/ccsrc/fl/server/iteration_metrics.h @@ -39,20 +39,11 @@ constexpr auto kRejectedClientNum = "rejectedClientNum"; constexpr auto kMetricsAuc = "metricsAuc"; constexpr auto kMetricsLoss = "metricsLoss"; constexpr auto kIterExecutionTime = "iterationExecutionTime"; +constexpr auto kMetrics = "metrics"; const std::map kInstanceStateName = { {InstanceState::kRunning, "running"}, {InstanceState::kDisable, "disable"}, {InstanceState::kFinish, "finish"}}; -template -inline T JsonGetKeyWithException(const nlohmann::json &json, const std::string &key) { - if (!json.contains(key)) { - MS_LOG(EXCEPTION) << "The key " << key << "does not exist in json " << json.dump(); - } - return json[key].get(); -} - -constexpr auto kMetrics = "metrics"; - class IterationMetrics { public: explicit IterationMetrics(const std::string &config_file) diff --git a/mindspore/ccsrc/fl/server/iteration_timer.cc b/mindspore/ccsrc/fl/server/iteration_timer.cc index 780c2ff2f16..0f3ba5c509e 100644 --- a/mindspore/ccsrc/fl/server/iteration_timer.cc +++ b/mindspore/ccsrc/fl/server/iteration_timer.cc @@ -19,6 +19,13 @@ namespace mindspore { namespace fl { namespace server { +IterationTimer::~IterationTimer() { + running_ = false; + if (monitor_thread_.joinable()) { + monitor_thread_.join(); + } +} + void IterationTimer::Start(const std::chrono::milliseconds &duration) { if (running_.load()) { MS_LOG(WARNING) << "The timer already started."; @@ -50,7 +57,7 @@ void IterationTimer::SetTimeOutCallBack(const TimeOutCb &timeout_cb) { return; } -bool IterationTimer::IsTimeOut(const std::chrono::milliseconds ×tamp) const { +bool IterationTimer::IsTimeOut(const std::chrono::milliseconds ×tamp) { return timestamp > end_time_ ? true : false; } diff --git a/mindspore/ccsrc/fl/server/iteration_timer.h b/mindspore/ccsrc/fl/server/iteration_timer.h index 752e392f8d8..af6c3c0e1d5 100644 --- a/mindspore/ccsrc/fl/server/iteration_timer.h +++ b/mindspore/ccsrc/fl/server/iteration_timer.h @@ -30,7 +30,7 @@ namespace server { class IterationTimer { public: IterationTimer() : running_(false), end_time_(0) {} - ~IterationTimer() = default; + ~IterationTimer(); // Start timing. The timer will stop after parameter 'duration' milliseconds. void Start(const std::chrono::milliseconds &duration); @@ -42,7 +42,7 @@ class IterationTimer { void SetTimeOutCallBack(const TimeOutCb &timeout_cb); // Judge whether current timestamp is out of time window's range since the Start function is called. - bool IsTimeOut(const std::chrono::milliseconds ×tamp) const; + bool IsTimeOut(const std::chrono::milliseconds ×tamp); // Judge whether the timer is keeping timing. bool IsRunning() const; diff --git a/mindspore/ccsrc/fl/server/kernel/dense_grad_accum_kernel.h b/mindspore/ccsrc/fl/server/kernel/dense_grad_accum_kernel.h index eb3b5fd3bb8..5f45fa36e51 100644 --- a/mindspore/ccsrc/fl/server/kernel/dense_grad_accum_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/dense_grad_accum_kernel.h @@ -36,59 +36,10 @@ class DenseGradAccumKernel : public AggregationKernel { DenseGradAccumKernel() = default; ~DenseGradAccumKernel() override = default; - void InitKernel(const CNodePtr &kernel_node) override { - MS_EXCEPTION_IF_NULL(kernel_node); - std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); - if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 || - kNameToIdxMap.at(cnode_name).at("inputs").count("grad") == 0) { - MS_LOG(EXCEPTION) << "Can't find index info of grad for kernel " << cnode_name; - return; - } - size_t cnode_grad_idx = kNameToIdxMap.at(cnode_name).at("inputs").at("grad"); - std::vector grad_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, cnode_grad_idx); - size_t grad_size = std::accumulate(grad_shape.begin(), grad_shape.end(), sizeof(T), std::multiplies()); - input_size_list_.push_back(grad_size); - size_t new_grad_size = grad_size; - input_size_list_.push_back(new_grad_size); - GenerateReuseKernelNodeInfo(); - return; - } + void InitKernel(const CNodePtr &) override { return; } - bool Launch(const std::vector &inputs, const std::vector &, + bool Launch(const std::vector &, const std::vector &, const std::vector &) override { - if (inputs.size() != kDenseGradAccumKernelInputsNum) { - MS_LOG(ERROR) << "The inputs number of DenseGradAccumKernel should be 2, but got " << inputs.size(); - return false; - } - MS_ERROR_IF_NULL_W_RET_VAL(inputs[0], false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[1], false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false); - - if (accum_count_ == 0) { - int ret = memset_s(inputs[0]->addr, inputs[0]->size, 0x00, inputs[0]->size); - if (ret != 0) { - MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")"; - return false; - } - } - - T *grad_addr = reinterpret_cast(inputs[0]->addr); - T *new_grad_addr = reinterpret_cast(inputs[1]->addr); - for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) { - grad_addr[i] += new_grad_addr[i]; - } - - accum_count_++; - if (accum_count_ > done_count_) { - MS_LOG(ERROR) << "accum_count_ should not be greater than done_count_ " << done_count_; - return false; - } - if (accum_count_ == done_count_) { - for (size_t i = 0; i < inputs[0]->size / sizeof(T); i++) { - grad_addr[i] /= done_count_; - } - } return true; } diff --git a/mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h b/mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h index 5a5a4ab2f11..b379bc23790 100644 --- a/mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/fed_avg_kernel.h @@ -108,8 +108,8 @@ class FedAvgKernel : public AggregationKernel { done_ = true; return; }; - GenerateReuseKernelNodeInfo(); DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler_, last_cnt_handler_}); + GenerateReuseKernelNodeInfo(); return; } @@ -119,14 +119,9 @@ class FedAvgKernel : public AggregationKernel { MS_LOG(ERROR) << "The inputs number of FedAvgKernel should be 4, but got " << inputs.size(); return false; } - MS_ERROR_IF_NULL_W_RET_VAL(inputs[0], false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[1], false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[2], false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[3], false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[0]->addr, false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[1]->addr, false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[2]->addr, false); - MS_ERROR_IF_NULL_W_RET_VAL(inputs[3]->addr, false); + for (size_t i = 0; i < inputs.size(); i++) { + MS_ERROR_IF_NULL_W_RET_VAL(inputs[i]->addr, false); + } std::unique_lock lock(weight_mutex_); // The weight and new_weight values should be multiplied by clients already, so we don't need to do multiplication @@ -181,7 +176,7 @@ class FedAvgKernel : public AggregationKernel { bool ReInitForUpdatingHyperParams(size_t aggr_threshold) override { done_count_ = aggr_threshold; if (!DistributedCountService::GetInstance().ReInitCounter(name_, done_count_)) { - MS_LOG(ERROR) << "Reinitializing counter for " << name_ << " failed."; + MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed."; return false; } return true; diff --git a/mindspore/ccsrc/fl/server/kernel/kernel_factory.h b/mindspore/ccsrc/fl/server/kernel/kernel_factory.h index 3ae6d26cc08..7a3e6afbfa9 100644 --- a/mindspore/ccsrc/fl/server/kernel/kernel_factory.h +++ b/mindspore/ccsrc/fl/server/kernel/kernel_factory.h @@ -21,7 +21,7 @@ #include #include #include -#include "utils/hash_map.h" +#include #include "fl/server/common.h" #include "fl/server/kernel/params_info.h" @@ -83,7 +83,7 @@ class KernelFactory { // Generally, a server kernel can correspond to several ParamsInfo which is registered by the method 'Register' in // server kernel's *.cc files. - mindspore::HashMap>> name_to_creator_map_; + std::unordered_map>> name_to_creator_map_; }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/optimizer_kernel_factory.cc b/mindspore/ccsrc/fl/server/kernel/optimizer_kernel_factory.cc index 8d37ab32ba5..bfceda8de89 100644 --- a/mindspore/ccsrc/fl/server/kernel/optimizer_kernel_factory.cc +++ b/mindspore/ccsrc/fl/server/kernel/optimizer_kernel_factory.cc @@ -21,49 +21,7 @@ namespace mindspore { namespace fl { namespace server { namespace kernel { -bool OptimizerKernelFactory::Matched(const ParamsInfo ¶ms_info, const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); - if (kNameToIdxMap.count(cnode_name) == 0) { - MS_LOG(ERROR) << "Can't find index info for kernel " << cnode_name; - return false; - } - - auto input_name_to_idx = kNameToIdxMap.at(cnode_name).at("inputs"); - size_t input_num = params_info.inputs_num(); - for (size_t i = 0; i < input_num; i++) { - auto one_input_name_type = params_info.inputs_name_type(i); - std::string name = one_input_name_type.first; - if (input_name_to_idx.count(name) == 0) { - MS_LOG(EXCEPTION) << cnode_name << " does not have input named " << name; - return false; - } - size_t input_idx = input_name_to_idx.at(name); - TypeId kernel_node_input_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_idx); - TypeId registered_input_type = one_input_name_type.second; - if (registered_input_type != kernel_node_input_type) { - return false; - } - } - - auto output_name_to_idx = kNameToIdxMap.at(cnode_name).at("outputs"); - size_t output_num = params_info.outputs_num(); - for (size_t i = 0; i < output_num; i++) { - auto one_output_name_type = params_info.outputs_name_type(i); - std::string name = one_output_name_type.first; - if (output_name_to_idx.count(name) == 0) { - MS_LOG(EXCEPTION) << cnode_name << " does not have output named " << name; - return false; - } - size_t output_idx = output_name_to_idx.at(name); - TypeId kernel_node_output_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_idx); - TypeId registered_output_type = one_output_name_type.second; - if (registered_output_type != kernel_node_output_type) { - return false; - } - } - return true; -} +bool OptimizerKernelFactory::Matched(const ParamsInfo &, const CNodePtr &) { return true; } } // namespace kernel } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc index d9ff6359d7e..70f14ec55a4 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "schema/cipher_generated.h" namespace mindspore { @@ -29,16 +30,50 @@ void ClientListKernel::InitKernel(size_t) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); } - - executor_ = &Executor::GetInstance(); - MS_EXCEPTION_IF_NULL(executor_); - if (!executor_->initialized()) { - MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; - return; - } cipher_init_ = &armour::CipherInit::GetInstance(); } +sigVerifyResult ClientListKernel::VerifySignature(const schema::GetClientList *get_clients_req) { + std::string fl_id = get_clients_req->fl_id()->str(); + std::string timestamp = get_clients_req->timestamp()->str(); + int iteration = get_clients_req->iteration(); + std::string iter_str = std::to_string(iteration); + auto fbs_signature = get_clients_req->signature(); + std::vector signature; + if (fbs_signature == nullptr) { + MS_LOG(ERROR) << "signature in get_clients_req is nullptr"; + return sigVerifyResult::FAILED; + } + signature.assign(fbs_signature->begin(), fbs_signature->end()); + std::map key_attestations; + const fl::PBMetadata &key_attestations_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation); + const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation(); + auto iter = key_attestation_pb.key_attestations().begin(); + for (; iter != key_attestation_pb.key_attestations().end(); ++iter) { + (void)key_attestations.emplace(std::pair(iter->first, iter->second)); + } + if (key_attestations.find(fl_id) == key_attestations.end()) { + MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id; + return sigVerifyResult::FAILED; + } + + std::vector src_data; + (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end()); + (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end()); + mindspore::ps::server::CertVerify certVerify; + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH); + if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) { + return sigVerifyResult::FAILED; + } + if (!certVerify.verifyTimeStamp(fl_id, timestamp)) { + return sigVerifyResult::TIMEOUT; + } + MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success."; + return sigVerifyResult::PASSED; +} + bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req, const std::shared_ptr &fbb) { std::vector client_list; @@ -48,7 +83,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient if (!LocalMetaStore::GetInstance().has_value(kCtxUpdateModelThld)) { MS_LOG(ERROR) << "update_model_client_threshold is not set."; BuildClientListRsp(fbb, schema::ResponseCode_SystemError, "update_model_client_threshold is not set.", - empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + empty_client_list, std::to_string(CURRENT_TIME_MILLI.count()), iter_num); return false; } uint64_t update_model_client_needed = LocalMetaStore::GetInstance().value(kCtxUpdateModelThld); @@ -57,11 +92,11 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient for (size_t i = 0; i < IntToSize(client_list_pb.fl_id_size()); ++i) { client_list.push_back(client_list_pb.fl_id(SizeToInt(i))); } - if (static_cast(client_list.size()) < update_model_client_needed) { + if (client_list.size() < update_model_client_needed) { MS_LOG(INFO) << "The server is not ready. update_model_client_needed: " << update_model_client_needed; MS_LOG(INFO) << "now update_model_client_num: " << client_list_pb.fl_id_size(); BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", empty_client_list, - std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); return false; } @@ -69,7 +104,7 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient std::string reason = "fl_id: " + fl_id + " is not in the update_model_clients"; MS_LOG(INFO) << reason; BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, empty_client_list, - std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); return false; } @@ -79,34 +114,27 @@ bool ClientListKernel::DealClient(const size_t iter_num, const schema::GetClient std::string reason = "update get update model clients failed"; MS_LOG(ERROR) << reason; BuildClientListRsp(fbb, schema::ResponseCode_SucNotReady, reason, empty_client_list, - std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); return false; } + if (!DistributedCountService::GetInstance().Count(name_, get_clients_req->fl_id()->str())) { std::string reason = "Counting for get user list request failed. Please retry later."; BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, empty_client_list, - std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); MS_LOG(ERROR) << reason; return false; } - MS_LOG(INFO) << "send clients_list succeed!"; - MS_LOG(INFO) << "UpdateModel client list: "; - for (size_t i = 0; i < client_list.size(); ++i) { - MS_LOG(INFO) << " fl_id : " << client_list[i]; - } MS_LOG(INFO) << "update_model_client_needed: " << update_model_client_needed; BuildClientListRsp(fbb, schema::ResponseCode_SUCCEED, "send clients_list succeed!", client_list, - std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); return true; } bool ClientListKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); - size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); - MS_LOG(INFO) << "Iteration number is " << iter_num << ", ClientListKernel total duration is " << total_duration; - clock_t start_time = clock(); - + MS_LOG(INFO) << "Launching ClientListKernel, Iteration number is " << iter_num; if (inputs.size() != 1 || outputs.size() != 1) { std::string reason = "inputs or outputs size is invalid."; MS_LOG(ERROR) << reason; @@ -121,26 +149,66 @@ bool ClientListKernel::Launch(const std::vector &inputs, const std:: return false; } std::vector client_list; + flatbuffers::Verifier verifier(reinterpret_cast(req_data), inputs[0]->size); + if (!verifier.VerifyBuffer()) { + std::string reason = "The schema of GetClientList is invalid."; + BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } const schema::GetClientList *get_clients_req = flatbuffers::GetRoot(req_data); + if (get_clients_req == nullptr) { + std::string reason = "Building flatbuffers schema failed for GetClientList."; + BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + // verify signature + if (ps::PSContext::instance()->pki_verify()) { + sigVerifyResult verify_result = VerifySignature(get_clients_req); + if (verify_result == sigVerifyResult::FAILED) { + std::string reason = "verify signature failed."; + BuildClientListRsp(fbb, schema::ResponseCode_RequestError, reason, client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + if (verify_result == sigVerifyResult::TIMEOUT) { + std::string reason = "verify signature timestamp failed."; + BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, reason, client_list, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + if (verify_result == sigVerifyResult::PASSED) { + MS_LOG(INFO) << "verify signature passed!"; + } + } + size_t iter_client = IntToSize(get_clients_req->iteration()); if (iter_num != iter_client) { MS_LOG(ERROR) << "client list iteration number is invalid: server now iteration is " << iter_num << ". client request iteration is " << iter_client; BuildClientListRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", client_list, - std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { - MS_LOG(ERROR) << "Current amount for GetClientList is enough."; + MS_LOG(WARNING) << "Current amount for GetClientList is enough."; } - (void)DealClient(iter_num, get_clients_req, fbb); + if (!DealClient(iter_num, get_clients_req, fbb)) { + MS_LOG(WARNING) << "Get Client List not ready."; + } GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - clock_t end_time = clock(); - double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); - MS_LOG(INFO) << "client_list_kernel success time is : " << duration; return true; } // namespace fl @@ -153,28 +221,28 @@ bool ClientListKernel::Reset() { return true; } -void ClientListKernel::BuildClientListRsp(const std::shared_ptr &client_list_resp_builder, +void ClientListKernel::BuildClientListRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const string &reason, std::vector clients, const string &next_req_time, - const int iteration) { - auto rsp_reason = client_list_resp_builder->CreateString(reason); - auto rsp_next_req_time = client_list_resp_builder->CreateString(next_req_time); + const size_t iteration) { + auto rsp_reason = fbb->CreateString(reason); + auto rsp_next_req_time = fbb->CreateString(next_req_time); std::vector> clients_vector; for (auto client : clients) { - auto client_fb = client_list_resp_builder->CreateString(client); + auto client_fb = fbb->CreateString(client); clients_vector.push_back(client_fb); MS_LOG(WARNING) << "update client list: "; MS_LOG(WARNING) << client; } - auto clients_fb = client_list_resp_builder->CreateVector(clients_vector); - schema::ReturnClientListBuilder rsp_builder(*(client_list_resp_builder.get())); - rsp_builder.add_retcode(retcode); + auto clients_fb = fbb->CreateVector(clients_vector); + schema::ReturnClientListBuilder rsp_builder(*(fbb.get())); + rsp_builder.add_retcode(SizeToInt(retcode)); rsp_builder.add_reason(rsp_reason); rsp_builder.add_clients(clients_fb); - rsp_builder.add_iteration(iteration); + rsp_builder.add_iteration(SizeToInt(iteration)); rsp_builder.add_next_req_time(rsp_next_req_time); auto rsp_exchange_keys = rsp_builder.Finish(); - client_list_resp_builder->Finish(rsp_exchange_keys); + fbb->Finish(rsp_exchange_keys); return; } diff --git a/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.h index 6f13eb279dd..40f9b946069 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/client_list_kernel.h @@ -29,6 +29,9 @@ namespace mindspore { namespace fl { namespace server { namespace kernel { +// results of signature verification +enum sigVerifyResult { FAILED, TIMEOUT, PASSED }; + class ClientListKernel : public RoundKernel { public: ClientListKernel() = default; @@ -37,12 +40,13 @@ class ClientListKernel : public RoundKernel { bool Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) override; bool Reset() override; - void BuildClientListRsp(const std::shared_ptr &client_list_resp_builder, - const schema::ResponseCode retcode, const string &reason, std::vector clients, - const string &next_req_time, const int iteration); + void BuildClientListRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const string &reason, std::vector clients, const string &next_req_time, + const size_t iteration); private: armour::CipherInit *cipher_init_; + sigVerifyResult VerifySignature(const schema::GetClientList *get_clients_req); bool DealClient(const size_t iter_num, const schema::GetClientList *get_clients_req, const std::shared_ptr &fbb); Executor *executor_; diff --git a/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc index 4193229e2f6..687d322e0b0 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.cc @@ -18,6 +18,7 @@ #include #include #include +#include namespace mindspore { namespace fl { @@ -27,24 +28,15 @@ void ExchangeKeysKernel::InitKernel(size_t) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); } - - executor_ = &Executor::GetInstance(); - MS_EXCEPTION_IF_NULL(executor_); - if (!executor_->initialized()) { - MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; - return; - } - cipher_key_ = &armour::CipherKeys::GetInstance(); } -bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr &fbb, const int iter_num) { +bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr &fbb, const size_t iter_num) { if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { std::string reason = "Current amount for exchangeKey is enough. Please retry later."; cipher_key_->BuildExchangeKeysRsp( fbb, schema::ResponseCode_OutOfTime, reason, - std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp)), - IntToSize(iter_num)); + std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp)), iter_num); MS_LOG(WARNING) << reason; return true; } @@ -53,30 +45,84 @@ bool ExchangeKeysKernel::ReachThresholdForExchangeKeys(const std::shared_ptr &fbb, const schema::RequestExchangeKeys *exchange_keys_req, - const int iter_num) { + const size_t iter_num) { MS_ERROR_IF_NULL_W_RET_VAL(exchange_keys_req, false); if (!DistributedCountService::GetInstance().Count(name_, exchange_keys_req->fl_id()->str())) { std::string reason = "Counting for exchange kernel request failed. Please retry later."; cipher_key_->BuildExchangeKeysRsp( fbb, schema::ResponseCode_OutOfTime, reason, - std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp)), - IntToSize(iter_num)); + std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp)), iter_num); MS_LOG(ERROR) << reason; return false; } return true; } +sigVerifyResult ExchangeKeysKernel::VerifySignature(const schema::RequestExchangeKeys *exchange_keys_req) { + std::string fl_id = exchange_keys_req->fl_id()->str(); + std::string timestamp = exchange_keys_req->timestamp()->str(); + int iteration = exchange_keys_req->iteration(); + std::string iter_str = std::to_string(iteration); + auto fbs_signature = exchange_keys_req->signature(); + std::vector signature; + if (fbs_signature == nullptr) { + MS_LOG(ERROR) << "signature in exchange_keys_req is nullptr"; + return sigVerifyResult::FAILED; + } + signature.assign(fbs_signature->begin(), fbs_signature->end()); + std::map key_attestations; + const fl::PBMetadata &key_attestations_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation); + const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation(); + auto iter = key_attestation_pb.key_attestations().begin(); + for (; iter != key_attestation_pb.key_attestations().end(); ++iter) { + (void)key_attestations.emplace(std::pair(iter->first, iter->second)); + } + if (key_attestations.find(fl_id) == key_attestations.end()) { + MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id; + return sigVerifyResult::FAILED; + } + + auto fbs_cpk = exchange_keys_req->c_pk(); + auto fbs_spk = exchange_keys_req->s_pk(); + if (fbs_cpk == nullptr || fbs_spk == nullptr) { + MS_LOG(ERROR) << "public key from exchange_keys_req is null"; + return sigVerifyResult::FAILED; + } + size_t spk_len = fbs_spk->size(); + size_t cpk_len = fbs_cpk->size(); + std::vector cpk(cpk_len); + std::vector spk(spk_len); + bool ret_create_code_cpk = mindspore::armour::CreateArray(&cpk, *fbs_cpk); + bool ret_create_code_spk = mindspore::armour::CreateArray(&spk, *fbs_spk); + if (!(ret_create_code_cpk && ret_create_code_spk)) { + MS_LOG(ERROR) << "create array for public keys failed"; + return sigVerifyResult::FAILED; + } + + std::vector src_data; + (void)src_data.insert(src_data.end(), cpk.begin(), cpk.end()); + (void)src_data.insert(src_data.end(), spk.begin(), spk.end()); + (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end()); + (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end()); + mindspore::ps::server::CertVerify certVerify; + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH); + if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) { + return sigVerifyResult::FAILED; + } + if (!certVerify.verifyTimeStamp(fl_id, timestamp)) { + return sigVerifyResult::TIMEOUT; + } + MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success."; + return sigVerifyResult::PASSED; +} + bool ExchangeKeysKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { - MS_LOG(INFO) << "Launching ExchangeKey kernel."; - bool response = false; size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); - size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); - MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ExchangeKeysKernel allowed Duration Is " - << total_duration; - clock_t start_time = clock(); - + MS_LOG(INFO) << "Launching ExchangeKey kernel, ITERATION NUMBER IS : " << iter_num; + bool response = false; if (inputs.size() != 1 || outputs.size() != 1) { std::string reason = "inputs or outputs size is invalid."; MS_LOG(ERROR) << reason; @@ -91,11 +137,56 @@ bool ExchangeKeysKernel::Launch(const std::vector &inputs, const std return false; } - if (ReachThresholdForExchangeKeys(fbb, SizeToInt(iter_num))) { + if (ReachThresholdForExchangeKeys(fbb, iter_num)) { + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + flatbuffers::Verifier verifier(reinterpret_cast(req_data), inputs[0]->size); + if (!verifier.VerifyBuffer()) { + std::string reason = "The schema of RequestExchangeKeys is invalid."; + cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } const schema::RequestExchangeKeys *exchange_keys_req = flatbuffers::GetRoot(req_data); + if (exchange_keys_req == nullptr) { + std::string reason = "Building flatbuffers schema failed for ExchangeKeys."; + cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + // verify signature + if (ps::PSContext::instance()->pki_verify()) { + sigVerifyResult verify_result = VerifySignature(exchange_keys_req); + if (verify_result == sigVerifyResult::FAILED) { + std::string reason = "verify signature failed."; + cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::TIMEOUT) { + std::string reason = "verify signature timestamp failed."; + cipher_key_->BuildExchangeKeysRsp(fbb, schema::ResponseCode_OutOfTime, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::PASSED) { + MS_LOG(INFO) << "verify signature passed!"; + } + } + size_t iter_client = IntToSize(exchange_keys_req->iteration()); if (iter_num != iter_client) { MS_LOG(ERROR) << "ExchangeKeys iteration number is invalid: server now iteration is " << iter_num @@ -107,19 +198,16 @@ bool ExchangeKeysKernel::Launch(const std::vector &inputs, const std } response = cipher_key_->ExchangeKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), exchange_keys_req, fbb); if (!response) { - MS_LOG(WARNING) << "update exchange keys is failed."; + MS_LOG(ERROR) << "update exchange keys is failed."; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } - if (!CountForExchangeKeys(fbb, exchange_keys_req, SizeToInt(iter_num))) { + if (!CountForExchangeKeys(fbb, exchange_keys_req, iter_num)) { MS_LOG(ERROR) << "count for exchange keys failed."; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - clock_t end_time = clock(); - double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); - MS_LOG(INFO) << "ExchangeKeysKernel DURATION TIME IS : " << duration; return true; } diff --git a/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.h index 3783f0df18a..18bf2fddea8 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/exchange_keys_kernel.h @@ -25,11 +25,15 @@ #include "fl/server/kernel/round/round_kernel_factory.h" #include "fl/server/executor.h" #include "fl/armour/cipher/cipher_keys.h" +#include "fl/armour/cipher/cipher_meta_storage.h" namespace mindspore { namespace fl { namespace server { namespace kernel { +// results of signature verification +enum sigVerifyResult { FAILED, TIMEOUT, PASSED }; + class ExchangeKeysKernel : public RoundKernel { public: ExchangeKeysKernel() = default; @@ -43,9 +47,10 @@ class ExchangeKeysKernel : public RoundKernel { Executor *executor_; size_t iteration_time_window_; armour::CipherKeys *cipher_key_; - bool ReachThresholdForExchangeKeys(const std::shared_ptr &fbb, const int iter_num); + sigVerifyResult VerifySignature(const schema::RequestExchangeKeys *exchange_keys_req); + bool ReachThresholdForExchangeKeys(const std::shared_ptr &fbb, const size_t iter_num); bool CountForExchangeKeys(const std::shared_ptr &fbb, const schema::RequestExchangeKeys *exchange_keys_req, - const int iter_num); + const size_t iter_num); }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc index 2a7393f43d4..9b818fa1182 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.cc @@ -17,6 +17,8 @@ #include "fl/server/kernel/round/get_keys_kernel.h" #include #include +#include +#include namespace mindspore { namespace fl { @@ -26,24 +28,16 @@ void GetKeysKernel::InitKernel(size_t) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); } - - executor_ = &Executor::GetInstance(); - MS_EXCEPTION_IF_NULL(executor_); - if (!executor_->initialized()) { - MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; - return; - } - cipher_key_ = &armour::CipherKeys::GetInstance(); } bool GetKeysKernel::CountForGetKeys(const std::shared_ptr &fbb, const schema::GetExchangeKeys *get_keys_req, - const int iter_num) { + const size_t iter_num) { MS_ERROR_IF_NULL_W_RET_VAL(get_keys_req, false); if (!DistributedCountService::GetInstance().Count(name_, get_keys_req->fl_id()->str())) { std::string reason = "Counting for getkeys kernel request failed. Please retry later."; cipher_key_->BuildGetKeysRsp( - fbb, schema::ResponseCode_OutOfTime, IntToSize(iter_num), + fbb, schema::ResponseCode_OutOfTime, iter_num, std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp)), false); MS_LOG(ERROR) << reason; return false; @@ -51,16 +45,52 @@ bool GetKeysKernel::CountForGetKeys(const std::shared_ptr &fbb, const return true; } +sigVerifyResult GetKeysKernel::VerifySignature(const schema::GetExchangeKeys *get_keys_req) { + std::string fl_id = get_keys_req->fl_id()->str(); + std::string timestamp = get_keys_req->timestamp()->str(); + int iteration = get_keys_req->iteration(); + std::string iter_str = std::to_string(iteration); + auto fbs_signature = get_keys_req->signature(); + std::vector signature; + if (fbs_signature == nullptr) { + MS_LOG(ERROR) << "signature in get_keys_req is nullptr"; + return sigVerifyResult::FAILED; + } + signature.assign(fbs_signature->begin(), fbs_signature->end()); + std::map key_attestations; + const fl::PBMetadata &key_attestations_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation); + const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation(); + auto iter = key_attestation_pb.key_attestations().begin(); + for (; iter != key_attestation_pb.key_attestations().end(); ++iter) { + (void)key_attestations.emplace(std::pair(iter->first, iter->second)); + } + if (key_attestations.find(fl_id) == key_attestations.end()) { + MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id; + return sigVerifyResult::FAILED; + } + + std::vector src_data; + (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end()); + (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end()); + mindspore::ps::server::CertVerify certVerify; + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH); + if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) { + return sigVerifyResult::FAILED; + } + if (!certVerify.verifyTimeStamp(fl_id, timestamp)) { + return sigVerifyResult::TIMEOUT; + } + MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success."; + return sigVerifyResult::PASSED; +} + bool GetKeysKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { - MS_LOG(INFO) << "Launching GetKeys kernel."; - bool response = false; size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); - size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); - MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetKeysKernel allowed Duration Is " - << total_duration; - clock_t start_time = clock(); - + MS_LOG(INFO) << "Launching GetKeys kernel, ITERATION NUMBER IS : " << iter_num; + bool response = false; if (inputs.size() != 1 || outputs.size() != 1) { std::string reason = "inputs or outputs size is invalid."; MS_LOG(ERROR) << reason; @@ -74,12 +104,54 @@ bool GetKeysKernel::Launch(const std::vector &inputs, const std::vec MS_LOG(ERROR) << reason; return false; } - if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { - MS_LOG(ERROR) << "Current amount for GetKeysKernel is enough."; + MS_LOG(WARNING) << "Current amount for GetKeysKernel is enough."; + } + flatbuffers::Verifier verifier(reinterpret_cast(req_data), inputs[0]->size); + if (!verifier.VerifyBuffer()) { + std::string reason = "The schema of GetExchangeKeys is invalid."; + cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num, + std::to_string(CURRENT_TIME_MILLI.count()), false); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot(req_data); + if (get_exchange_keys_req == nullptr) { + std::string reason = "Building flatbuffers schema failed for GetExchangeKeys."; + cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num, + std::to_string(CURRENT_TIME_MILLI.count()), false); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + // verify signature + if (ps::PSContext::instance()->pki_verify()) { + sigVerifyResult verify_result = VerifySignature(get_exchange_keys_req); + if (verify_result == sigVerifyResult::FAILED) { + std::string reason = "verify signature failed."; + cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_RequestError, iter_num, + std::to_string(CURRENT_TIME_MILLI.count()), false); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::TIMEOUT) { + std::string reason = "verify signature timestamp failed."; + cipher_key_->BuildGetKeysRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, + std::to_string(CURRENT_TIME_MILLI.count()), false); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::PASSED) { + MS_LOG(INFO) << "verify signature passed!"; + } } - const schema::GetExchangeKeys *get_exchange_keys_req = flatbuffers::GetRoot(req_data); size_t iter_client = IntToSize(get_exchange_keys_req->iteration()); if (iter_num != iter_client) { MS_LOG(ERROR) << "GetKeysKernel iteration invalid. server now iteration is " << iter_num @@ -91,18 +163,15 @@ bool GetKeysKernel::Launch(const std::vector &inputs, const std::vec } response = cipher_key_->GetKeys(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_exchange_keys_req, fbb); if (!response) { - MS_LOG(WARNING) << "get public keys is failed."; + MS_LOG(WARNING) << "get public keys not ready."; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } - if (!CountForGetKeys(fbb, get_exchange_keys_req, SizeToInt(iter_num))) { + if (!CountForGetKeys(fbb, get_exchange_keys_req, iter_num)) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } GenerateOutput(outputs, fbb->GetCurrentBufferPointer(), fbb->GetSize()); - clock_t end_time = clock(); - double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); - MS_LOG(INFO) << "GetKeysKernel DURATION TIME IS : " << duration; return true; } diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.h index 6b5fd8187d8..3bc78ae9a91 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/get_keys_kernel.h @@ -30,6 +30,9 @@ namespace mindspore { namespace fl { namespace server { namespace kernel { +// results of signature verification +enum sigVerifyResult { FAILED, TIMEOUT, PASSED }; + class GetKeysKernel : public RoundKernel { public: GetKeysKernel() = default; @@ -43,8 +46,9 @@ class GetKeysKernel : public RoundKernel { Executor *executor_; size_t iteration_time_window_; armour::CipherKeys *cipher_key_; + sigVerifyResult VerifySignature(const schema::GetExchangeKeys *get_keys_req); bool CountForGetKeys(const std::shared_ptr &fbb, const schema::GetExchangeKeys *get_keys_req, - const int iter_num); + const size_t iter_num); }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.cc new file mode 100644 index 00000000000..489d5fd28cb --- /dev/null +++ b/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.cc @@ -0,0 +1,281 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fl/server/kernel/round/get_list_sign_kernel.h" +#include +#include +#include +#include +#include +#include "schema/cipher_generated.h" + +namespace mindspore { +namespace fl { +namespace server { +namespace kernel { +void GetListSignKernel::InitKernel(size_t) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + cipher_init_ = &armour::CipherInit::GetInstance(); +} + +sigVerifyResult GetListSignKernel::VerifySignature(const schema::RequestAllClientListSign *client_list_sign_req) { + std::string fl_id = client_list_sign_req->fl_id()->str(); + std::string timestamp = client_list_sign_req->timestamp()->str(); + int iteration = client_list_sign_req->iteration(); + std::string iter_str = std::to_string(iteration); + auto fbs_signature = client_list_sign_req->signature(); + std::vector signature; + if (fbs_signature == nullptr) { + MS_LOG(ERROR) << "signature in client_list_sign_req is nullptr"; + return sigVerifyResult::FAILED; + } + signature.assign(fbs_signature->begin(), fbs_signature->end()); + std::map key_attestations; + const fl::PBMetadata &key_attestations_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation); + const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation(); + auto iter = key_attestation_pb.key_attestations().begin(); + for (; iter != key_attestation_pb.key_attestations().end(); ++iter) { + (void)key_attestations.emplace(std::pair(iter->first, iter->second)); + } + if (key_attestations.find(fl_id) == key_attestations.end()) { + MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id; + return sigVerifyResult::FAILED; + } + + std::vector src_data; + (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end()); + (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end()); + mindspore::ps::server::CertVerify certVerify; + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH); + if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) { + return sigVerifyResult::FAILED; + } + if (!certVerify.verifyTimeStamp(fl_id, timestamp)) { + return sigVerifyResult::TIMEOUT; + } + MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success."; + return sigVerifyResult::PASSED; +} + +bool GetListSignKernel::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); + MS_LOG(INFO) << "Launching GetListSign kernel, Iteration number is " << iter_num; + if (inputs.size() != 1 || outputs.size() != 1) { + std::string reason = "inputs or outputs size is invalid."; + MS_LOG(ERROR) << reason; + return false; + } + std::shared_ptr fbb = std::make_shared(); + void *req_data = inputs[0]->addr; + if (fbb == nullptr || req_data == nullptr) { + std::string reason = "FBBuilder builder or req_data is nullptr."; + MS_LOG(ERROR) << reason; + return false; + } + std::map> list_signs; + flatbuffers::Verifier verifier(reinterpret_cast(req_data), inputs[0]->size); + if (!verifier.VerifyBuffer()) { + std::string reason = "The schema of RequestAllClientListSign is invalid."; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + const schema::RequestAllClientListSign *get_list_sign_req = + flatbuffers::GetRoot(req_data); + if (get_list_sign_req == nullptr) { + std::string reason = "Building flatbuffers schema failed for RequestAllClientListSign."; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + // verify signature + if (ps::PSContext::instance()->pki_verify()) { + sigVerifyResult verify_result = VerifySignature(get_list_sign_req); + if (verify_result == sigVerifyResult::FAILED) { + std::string reason = "verify signature failed."; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::TIMEOUT) { + std::string reason = "verify signature timestamp failed."; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), + iter_num, list_signs); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::PASSED) { + MS_LOG(INFO) << "verify signature passed!"; + } + } + + size_t iter_client = IntToSize(get_list_sign_req->iteration()); + if (iter_num != iter_client) { + MS_LOG(ERROR) << "get list sign iteration number is invalid: server now iteration is " << iter_num + << ". client request iteration is " << iter_client; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, "iter num is error.", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num, list_signs); + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + std::string fl_id = get_list_sign_req->fl_id()->str(); + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + MS_LOG(WARNING) << "Current amount for GetListSignKernel is enough."; + } + if (!GetListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), get_list_sign_req, fbb)) { + MS_LOG(WARNING) << "get list signs not ready."; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + std::string count_reason = ""; + if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) { + std::string reason = "Counting for get list sign request failed. Please retry later. " + count_reason; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), + iter_num, list_signs); + MS_LOG(ERROR) << reason; + return true; + } + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; +} + +bool GetListSignKernel::GetListSign(const size_t cur_iterator, const std::string &next_req_time, + const schema::RequestAllClientListSign *get_list_sign_req, + const std::shared_ptr &fbb) { + MS_LOG(INFO) << "CipherMgr::SendClientListSign START"; + std::map> client_list_signs_empty; + std::map> client_list_signs_all; + const fl::PBMetadata &clients_sign_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientListSigns); + const fl::ClientListSign &clients_sign_pb = clients_sign_pb_out.client_list_sign(); + size_t cur_clients_sign_num = IntToSize(clients_sign_pb.client_list_sign_size()); + if (cur_clients_sign_num < cipher_init_->push_list_sign_threshold) { + MS_LOG(INFO) << "The server is not ready. push_list_sign_needed: " << cipher_init_->push_list_sign_threshold; + MS_LOG(INFO) << "now push_sign_client_num: " << clients_sign_pb.client_list_sign_size(); + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SucNotReady, "The server is not ready.", next_req_time, + cur_iterator, client_list_signs_empty); + return false; + } + + std::vector update_model_clients; + const PBMetadata update_model_clients_pb_out = + DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList); + const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list(); + for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) { + update_model_clients.push_back(update_model_clients_pb.fl_id(SizeToInt(i))); + } + + auto iter = clients_sign_pb.client_list_sign().begin(); + for (; iter != clients_sign_pb.client_list_sign().end(); ++iter) { + std::vector signature(iter->second.begin(), iter->second.end()); + (void)client_list_signs_all.emplace(std::pair>(iter->first, signature)); + } + + std::string fl_id = get_list_sign_req->fl_id()->str(); + if (client_list_signs_all.find(fl_id) == client_list_signs_all.end()) { + std::string reason; + if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) { + reason = "client not send list signature, but in update model client list."; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator, + client_list_signs_all); + } else { + reason = "client not send list signature, && client is illegal"; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator, + client_list_signs_empty); + } + MS_LOG(WARNING) << reason; + return false; + } + + if (client_list_signs_all.find(fl_id) != client_list_signs_all.end()) { + // the client has sended signature, return false. + std::string reason = "The server has received the request, please do not request again."; + MS_LOG(WARNING) << reason; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator, + client_list_signs_all); + return false; + } + + std::string reason = "send update model client list signature success. "; + BuildGetListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator, + client_list_signs_all); + MS_LOG(INFO) << "CipherMgr::Send Client ListSign Success"; + return true; +} + +bool GetListSignKernel::Reset() { + MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + MS_LOG(INFO) << "Get List Signature kernel reset!"; + DistributedCountService::GetInstance().ResetCounter(name_); + StopTimer(); + return true; +} + +void GetListSignKernel::BuildGetListSignKernelRsp(const std::shared_ptr &fbb, + const schema::ResponseCode retcode, const string &reason, + const string &next_req_time, const size_t iteration, + const std::map> &list_signs) { + auto rsp_reason = fbb->CreateString(reason); + auto rsp_next_req_time = fbb->CreateString(next_req_time); + if (list_signs.size() == 0) { + schema::ReturnAllClientListSignBuilder rsp_builder(*(fbb.get())); + rsp_builder.add_retcode(static_cast(retcode)); + rsp_builder.add_reason(rsp_reason); + rsp_builder.add_next_req_time(rsp_next_req_time); + rsp_builder.add_iteration(SizeToInt(iteration)); + auto rsp_get_list_sign = rsp_builder.Finish(); + fbb->Finish(rsp_get_list_sign); + return; + } + std::vector> client_list_signs; + for (auto iter = list_signs.begin(); iter != list_signs.end(); ++iter) { + auto fbs_fl_id = fbb->CreateString(iter->first); + auto fbs_sign = fbb->CreateVector(iter->second.data(), iter->second.size()); + auto cur_sign = schema::CreateClientListSign(*fbb, fbs_fl_id, fbs_sign); + client_list_signs.push_back(cur_sign); + } + auto all_signs = fbb->CreateVector(client_list_signs); + schema::ReturnAllClientListSignBuilder rsp_builder(*(fbb.get())); + rsp_builder.add_retcode(static_cast(retcode)); + rsp_builder.add_reason(rsp_reason); + rsp_builder.add_next_req_time(rsp_next_req_time); + rsp_builder.add_iteration(SizeToInt(iteration)); + rsp_builder.add_client_list_sign(all_signs); + auto rsp_get_list_sign = rsp_builder.Finish(); + fbb->Finish(rsp_get_list_sign); + return; +} + +REG_ROUND_KERNEL(getListSign, GetListSignKernel) +} // namespace kernel +} // namespace server +} // namespace fl +} // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.h new file mode 100644 index 00000000000..5758f29b7e1 --- /dev/null +++ b/mindspore/ccsrc/fl/server/kernel/round/get_list_sign_kernel.h @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_LIST_SIGN_KERNEL_H +#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_LIST_SIGN_KERNEL_H + +#include +#include +#include +#include +#include "fl/server/common.h" +#include "fl/server/kernel/round/round_kernel.h" +#include "fl/server/kernel/round/round_kernel_factory.h" +#include "fl/armour/cipher/cipher_init.h" +#include "fl/server/executor.h" +namespace mindspore { +namespace fl { +namespace server { +namespace kernel { +// results of signature verification +enum sigVerifyResult { FAILED, TIMEOUT, PASSED }; + +class GetListSignKernel : public RoundKernel { + public: + GetListSignKernel() = default; + ~GetListSignKernel() override = default; + void InitKernel(size_t required_cnt) override; + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) override; + bool Reset() override; + void BuildGetListSignKernelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const string &reason, const string &next_req_time, const size_t iteration, + const std::map> &list_signs); + + private: + armour::CipherInit *cipher_init_; + Executor *executor_; + size_t iteration_time_window_; + sigVerifyResult VerifySignature(const schema::RequestAllClientListSign *client_list_sign_req); + bool GetListSign(const size_t cur_iterator, const std::string &next_req_time, + const schema::RequestAllClientListSign *client_list_sign_req, + const std::shared_ptr &fbb); +}; +} // namespace kernel +} // namespace server +} // namespace fl +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_GET_LIST_SIGN_KERNEL_H diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc index ff3cab4f60e..9752cd9bece 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.cc @@ -67,9 +67,9 @@ bool GetModelKernel::Launch(const std::vector &inputs, const std::ve return true; } - (void)++retry_count_; + ++retry_count_; if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) { - MS_LOG(INFO) << "Launching GetModelKernel retry count is " << retry_count_.load(); + MS_LOG(INFO) << "Launching GetModelKernel kernel. Retry count is " << retry_count_.load(); } const schema::RequestGetModel *get_model_req = flatbuffers::GetRoot(req_data); @@ -95,14 +95,14 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons auto next_req_time = LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp); std::map feature_maps; size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); - size_t get_model_iter = IntToSize(get_model_req->iteration()); + size_t get_model_iter = static_cast(get_model_req->iteration()); const auto &iter_to_model = ModelStore::GetInstance().iteration_to_model(); size_t latest_iter_num = iter_to_model.rbegin()->first; // If this iteration is not finished yet, return ResponseCode_SucNotReady so that clients could get model later. - if ((current_iter == get_model_iter && latest_iter_num != current_iter)) { + if (current_iter == get_model_iter && latest_iter_num != current_iter) { std::string reason = "The model is not ready yet for iteration " + std::to_string(get_model_iter) + - ". Maybe this is because\n" + "1.Client doesn't send enough update model requests.\n" + - "2. Worker has not push all the weights to servers."; + ". Maybe this is because\n" + "1. Client doesn't not send enough update model request.\n" + + "2. Worker has not push weights to server."; BuildGetModelRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps, std::to_string(next_req_time)); if (retry_count_.load() % kPrintGetModelForEveryRetryTime == 1) { @@ -120,7 +120,7 @@ void GetModelKernel::GetModel(const schema::RequestGetModel *get_model_req, cons feature_maps = ModelStore::GetInstance().GetModelByIterNum(get_model_iter); } - MS_LOG(INFO) << "GetModel last iteration is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() + MS_LOG(INFO) << "GetModel last iteratin is valid or not: " << Iteration::GetInstance().is_last_iteration_valid() << ", next request time is " << next_req_time << ", current iteration is " << current_iter; BuildGetModelRsp(fbb, schema::ResponseCode_SUCCEED, "Get model for iteration " + std::to_string(get_model_iter), current_iter, feature_maps, std::to_string(next_req_time)); diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.h index d8b0947552f..bc71355533f 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/get_model_kernel.h @@ -51,7 +51,7 @@ class GetModelKernel : public RoundKernel { Executor *executor_; // The time window of one iteration. - size_t iteration_time_window_; + size_t iteration_time_window_{0}; // The count of retrying because the iteration is not finished. std::atomic retry_count_; diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc index 28143c31bd4..1bc038475b4 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.cc @@ -30,23 +30,15 @@ void GetSecretsKernel::InitKernel(size_t) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); } - - executor_ = &Executor::GetInstance(); - MS_EXCEPTION_IF_NULL(executor_); - if (!executor_->initialized()) { - MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; - return; - } - cipher_share_ = &armour::CipherShares::GetInstance(); } bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr &fbb, - const schema::GetShareSecrets *get_secrets_req, const int iter_num) { + const schema::GetShareSecrets *get_secrets_req, const size_t iter_num) { MS_ERROR_IF_NULL_W_RET_VAL(get_secrets_req, false); if (!DistributedCountService::GetInstance().Count(name_, get_secrets_req->fl_id()->str())) { std::string reason = "Counting for get secrets kernel request failed. Please retry later."; - cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, IntToSize(iter_num), + cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, std::to_string(CURRENT_TIME_MILLI.count()), nullptr); MS_LOG(ERROR) << reason; return false; @@ -54,16 +46,52 @@ bool GetSecretsKernel::CountForGetSecrets(const std::shared_ptr &fbb, return true; } +sigVerifyResult GetSecretsKernel::VerifySignature(const schema::GetShareSecrets *get_secrets_req) { + std::string fl_id = get_secrets_req->fl_id()->str(); + std::string timestamp = get_secrets_req->timestamp()->str(); + int iteration = get_secrets_req->iteration(); + std::string iter_str = std::to_string(iteration); + auto fbs_signature = get_secrets_req->signature(); + std::vector signature; + if (fbs_signature == nullptr) { + MS_LOG(ERROR) << "signature in get_secrets_req is nullptr"; + return sigVerifyResult::FAILED; + } + signature.assign(fbs_signature->begin(), fbs_signature->end()); + std::map key_attestations; + const fl::PBMetadata &key_attestations_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation); + const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation(); + auto iter = key_attestation_pb.key_attestations().begin(); + for (; iter != key_attestation_pb.key_attestations().end(); ++iter) { + (void)key_attestations.emplace(std::pair(iter->first, iter->second)); + } + if (key_attestations.find(fl_id) == key_attestations.end()) { + MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id; + return sigVerifyResult::FAILED; + } + + std::vector src_data; + (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end()); + (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end()); + mindspore::ps::server::CertVerify certVerify; + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH); + if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) { + return sigVerifyResult::FAILED; + } + if (!certVerify.verifyTimeStamp(fl_id, timestamp)) { + return sigVerifyResult::TIMEOUT; + } + MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success."; + return sigVerifyResult::PASSED; +} + bool GetSecretsKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { - bool response = false; size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); - MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); std::string next_timestamp = std::to_string(CURRENT_TIME_MILLI.count()); - size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); - MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total GetSecretsKernel allowed Duration Is " - << total_duration; - clock_t start_time = clock(); + MS_LOG(INFO) << "Launching get secrets kernel, ITERATION NUMBER IS : " << iter_num; if (inputs.size() != 1 || outputs.size() != 1) { std::string reason = "inputs or outputs size is invalid."; @@ -78,8 +106,46 @@ bool GetSecretsKernel::Launch(const std::vector &inputs, const std:: MS_LOG(ERROR) << reason; return false; } - + flatbuffers::Verifier verifier(reinterpret_cast(req_data), inputs[0]->size); + if (!verifier.VerifyBuffer()) { + std::string reason = "The schema of GetShareSecrets is invalid."; + cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } const schema::GetShareSecrets *get_secrets_req = flatbuffers::GetRoot(req_data); + if (get_secrets_req == nullptr) { + std::string reason = "Building flatbuffers schema failed for GetExchangeKeys."; + cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + // verify signature + if (ps::PSContext::instance()->pki_verify()) { + sigVerifyResult verify_result = VerifySignature(get_secrets_req); + if (verify_result == sigVerifyResult::FAILED) { + std::string reason = "verify signature failed."; + cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_RequestError, iter_num, next_timestamp, nullptr); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::TIMEOUT) { + std::string reason = "verify signature timestamp failed."; + cipher_share_->BuildGetSecretsRsp(fbb, schema::ResponseCode_OutOfTime, iter_num, next_timestamp, nullptr); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::PASSED) { + MS_LOG(INFO) << "verify signature passed!"; + } + } size_t iter_client = IntToSize(get_secrets_req->iteration()); if (iter_num != iter_client) { MS_LOG(ERROR) << "GetSecretsKernel iteration invalid. server now iteration is " << iter_num @@ -93,20 +159,17 @@ bool GetSecretsKernel::Launch(const std::vector &inputs, const std:: MS_LOG(ERROR) << "Current amount for GetSecretsKernel is enough."; } - response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp); + bool response = cipher_share_->GetSecrets(get_secrets_req, fbb, next_timestamp); if (!response) { - MS_LOG(WARNING) << "get secret shares is failed."; + MS_LOG(WARNING) << "get secret shares not ready."; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } - if (!CountForGetSecrets(fbb, get_secrets_req, SizeToInt(iter_num))) { + if (!CountForGetSecrets(fbb, get_secrets_req, iter_num)) { GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - clock_t end_time = clock(); - double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); - MS_LOG(INFO) << "GetSecretsKernel DURATION TIME is : " << duration; return true; } diff --git a/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.h index 22156f51e2e..378223f71d0 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/get_secrets_kernel.h @@ -29,6 +29,9 @@ namespace mindspore { namespace fl { namespace server { namespace kernel { +// results of signature verification +enum sigVerifyResult { FAILED, TIMEOUT, PASSED }; + class GetSecretsKernel : public RoundKernel { public: GetSecretsKernel() = default; @@ -42,8 +45,9 @@ class GetSecretsKernel : public RoundKernel { Executor *executor_; size_t iteration_time_window_; armour::CipherShares *cipher_share_; + sigVerifyResult VerifySignature(const schema::GetShareSecrets *get_secrets_req); bool CountForGetSecrets(const std::shared_ptr &fbb, const schema::GetShareSecrets *get_secrets_req, - const int iter_num); + const size_t iter_num); }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc index d7c7a32249b..3b3bb6df103 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/pull_weight_kernel.cc @@ -37,6 +37,13 @@ void PullWeightKernel::InitKernel(size_t) { bool PullWeightKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { MS_LOG(DEBUG) << "Launching PullWeightKernel kernel."; + if (inputs.size() != 1 || outputs.size() != 1) { + std::string reason = "inputs or outputs size is invalid."; + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, reason.c_str(), reason.size()); + return true; + } + void *req_data = inputs[0]->addr; std::shared_ptr fbb = std::make_shared(); if (fbb == nullptr || req_data == nullptr) { @@ -71,7 +78,7 @@ void PullWeightKernel::PullWeight(const std::shared_ptr &fbb, } std::map feature_maps = {}; size_t current_iter = LocalMetaStore::GetInstance().curr_iter_num(); - size_t pull_weight_iter = IntToSize(pull_weight_req->iteration()); + size_t pull_weight_iter = static_cast(pull_weight_req->iteration()); // The iteration from worker should be the same as server's, otherwise return SucNotReady so that worker could retry. if (pull_weight_iter != current_iter) { std::string reason = "PullWeight iteration " + std::to_string(pull_weight_iter) + @@ -91,7 +98,7 @@ void PullWeightKernel::PullWeight(const std::shared_ptr &fbb, weight_names.push_back(weights_names_fbs->Get(i)->str()); } if (!executor_->IsWeightAggrDone(weight_names) || !executor_->unmasked()) { - (void)++retry_count_; + ++retry_count_; std::string reason = "The aggregation for the weights is not done yet."; BuildPullWeightRsp(fbb, schema::ResponseCode_SucNotReady, reason, current_iter, feature_maps); if (retry_count_.load() % kPrintPullWeightForEveryRetryTime == 1) { @@ -134,7 +141,7 @@ void PullWeightKernel::BuildPullWeightRsp(const std::shared_ptr &fbb, auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); schema::ResponsePullWeightBuilder rsp_pull_weight_builder(*(fbb.get())); - rsp_pull_weight_builder.add_retcode(static_cast(retcode)); + rsp_pull_weight_builder.add_retcode(SizeToInt(retcode)); rsp_pull_weight_builder.add_reason(fbs_reason); rsp_pull_weight_builder.add_iteration(SizeToInt(iteration)); rsp_pull_weight_builder.add_feature_map(fbs_feature_maps_vector); diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.cc new file mode 100644 index 00000000000..53fe0540f48 --- /dev/null +++ b/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.cc @@ -0,0 +1,277 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fl/server/kernel/round/push_list_sign_kernel.h" +#include +#include +#include +#include +#include +#include "schema/cipher_generated.h" + +namespace mindspore { +namespace fl { +namespace server { +namespace kernel { +void PushListSignKernel::InitKernel(size_t) { + if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { + iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); + } + cipher_init_ = &armour::CipherInit::GetInstance(); +} + +bool PushListSignKernel::Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); + MS_LOG(INFO) << "Launching PushListSignKernel, Iteration number is " << iter_num; + if (inputs.size() != 1 || outputs.size() != 1) { + std::string reason = "inputs or outputs size is invalid."; + MS_LOG(ERROR) << reason; + return false; + } + std::shared_ptr fbb = std::make_shared(); + void *req_data = inputs[0]->addr; + if (fbb == nullptr || req_data == nullptr) { + std::string reason = "FBBuilder builder or req_data is nullptr."; + MS_LOG(ERROR) << reason; + return false; + } + flatbuffers::Verifier verifier(reinterpret_cast(req_data), inputs[0]->size); + if (!verifier.VerifyBuffer()) { + std::string reason = "The schema of PushClientListSign is invalid."; + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + const schema::SendClientListSign *client_list_sign_req = flatbuffers::GetRoot(req_data); + if (client_list_sign_req == nullptr) { + std::string reason = "Building flatbuffers schema failed for PushClientListSign."; + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + // verify signature + if (ps::PSContext::instance()->pki_verify()) { + sigVerifyResult verify_result = VerifySignature(client_list_sign_req); + if (verify_result == sigVerifyResult::FAILED) { + std::string reason = "verify signature failed."; + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + if (verify_result == sigVerifyResult::TIMEOUT) { + std::string reason = "verify signature timestamp failed."; + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + if (verify_result == sigVerifyResult::PASSED) { + MS_LOG(INFO) << "verify signature passed!"; + } + } + return LaunchForPushListSign(client_list_sign_req, iter_num, fbb, outputs); +} + +bool PushListSignKernel::LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req, + const size_t &iter_num, const std::shared_ptr &fbb, + const std::vector &outputs) { + size_t iter_client = IntToSize(client_list_sign_req->iteration()); + if (iter_num != iter_client) { + std::string reason = "push list sign iteration number is invalid"; + MS_LOG(WARNING) << reason; + MS_LOG(WARNING) << "server now iteration is " << iter_num << ". client request iteration is " << iter_client; + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), + iter_num); + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + std::vector update_model_clients; + const PBMetadata update_model_clients_pb_out = + DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList); + const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list(); + for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) { + update_model_clients.push_back(update_model_clients_pb.fl_id(i)); + } + std::string fl_id = client_list_sign_req->fl_id()->str(); + if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { + MS_LOG(ERROR) << "Current amount for PushListSignKernel is enough."; + if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) { + // client in get update model client list. + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, "Current amount for PushListSignKernel is enough.", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } else { + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, + "Current amount for PushListSignKernel is enough.", + std::to_string(CURRENT_TIME_MILLI.count()), iter_num); + } + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + if (!PushListSign(iter_num, std::to_string(CURRENT_TIME_MILLI.count()), client_list_sign_req, fbb, + update_model_clients)) { + MS_LOG(ERROR) << "push client list sign failed."; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + std::string count_reason = ""; + if (!DistributedCountService::GetInstance().Count(name_, fl_id, &count_reason)) { + std::string reason = "Counting for push list sign request failed. Please retry later. " + count_reason; + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(CURRENT_TIME_MILLI.count()), + iter_num); + MS_LOG(ERROR) << reason; + return true; + } + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; +} + +sigVerifyResult PushListSignKernel::VerifySignature(const schema::SendClientListSign *client_list_sign_req) { + std::string fl_id = client_list_sign_req->fl_id()->str(); + std::string timestamp = client_list_sign_req->timestamp()->str(); + int iteration = client_list_sign_req->iteration(); + std::string iter_str = std::to_string(iteration); + auto fbs_signature = client_list_sign_req->req_signature(); + std::vector signature; + if (fbs_signature == nullptr) { + MS_LOG(ERROR) << "signature in client_list_sign_req is nullptr"; + return sigVerifyResult::FAILED; + } + signature.assign(fbs_signature->begin(), fbs_signature->end()); + std::map key_attestations; + const fl::PBMetadata &key_attestations_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation); + const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation(); + auto iter = key_attestation_pb.key_attestations().begin(); + for (; iter != key_attestation_pb.key_attestations().end(); ++iter) { + (void)key_attestations.emplace(std::pair(iter->first, iter->second)); + } + if (key_attestations.find(fl_id) == key_attestations.end()) { + MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id; + return sigVerifyResult::FAILED; + } + + std::vector src_data; + (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end()); + (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end()); + mindspore::ps::server::CertVerify certVerify; + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH); + if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) { + return sigVerifyResult::FAILED; + } + if (!certVerify.verifyTimeStamp(fl_id, timestamp)) { + return sigVerifyResult::TIMEOUT; + } + MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success."; + return sigVerifyResult::PASSED; +} + +bool PushListSignKernel::PushListSign(const size_t cur_iterator, const std::string &next_req_time, + const schema::SendClientListSign *client_list_sign_req, + const std::shared_ptr &fbb, + const std::vector &update_model_clients) { + MS_LOG(INFO) << "CipherMgr::PushClientListSign START"; + std::vector get_client_list; // the clients which get update model client list + cipher_init_->cipher_meta_storage_.GetClientListFromServer(fl::server::kCtxGetUpdateModelClientList, + &get_client_list); + std::string fl_id = client_list_sign_req->fl_id()->str(); + if (find(get_client_list.begin(), get_client_list.end(), fl_id) == get_client_list.end()) { + // client not in get update model client list. + std::string reason = "client send signature is not in get update model client list. && client is illegal"; + MS_LOG(WARNING) << reason; + if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) { + // client in update model client list, client can move to next round + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator); + } else { + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator); + } + return false; + } + std::vector send_signs_clients; + const fl::PBMetadata &clients_sign_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientListSigns); + const fl::ClientListSign &clients_sign_pb = clients_sign_pb_out.client_list_sign(); + auto iter = clients_sign_pb.client_list_sign().begin(); + for (; iter != clients_sign_pb.client_list_sign().end(); ++iter) { + send_signs_clients.push_back(iter->first); + } + if (find(send_signs_clients.begin(), send_signs_clients.end(), fl_id) != send_signs_clients.end()) { + // the client has sended signature, return false. + std::string reason = "The server has received the request, please do not request again."; + MS_LOG(ERROR) << reason; + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator); + return false; + } + auto fbs_signature = client_list_sign_req->signature(); + std::vector signature; + if (fbs_signature != nullptr) { + signature.assign(fbs_signature->begin(), fbs_signature->end()); + } + fl::PairClientListSign pair_client_list_sign_pb; + pair_client_list_sign_pb.set_fl_id(fl_id); + pair_client_list_sign_pb.set_signature(signature.data(), signature.size()); + fl::PBMetadata pb_data; + pb_data.mutable_pair_client_list_sign()->MergeFrom(pair_client_list_sign_pb); + bool retcode = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxClientListSigns, pb_data); + if (!retcode) { + std::string reason = "store client list signature failed"; + MS_LOG(ERROR) << reason; + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_OutOfTime, reason, next_req_time, cur_iterator); + return false; + } + std::string reason = "send update model client list signature success. "; + BuildPushListSignKernelRsp(fbb, schema::ResponseCode_SUCCEED, reason, next_req_time, cur_iterator); + MS_LOG(INFO) << "CipherMgr::PushClientListSign Success"; + return true; +} + +bool PushListSignKernel::Reset() { + MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); + MS_LOG(INFO) << "Push list sign kernel reset!"; + DistributedCountService::GetInstance().ResetCounter(name_); + DistributedMetadataStore::GetInstance().ResetMetadata(kCtxClientListSigns); + StopTimer(); + return true; +} + +void PushListSignKernel::BuildPushListSignKernelRsp(const std::shared_ptr &fbb, + const schema::ResponseCode retcode, const string &reason, + const string &next_req_time, const size_t iteration) { + auto rsp_reason = fbb->CreateString(reason); + auto rsp_next_req_time = fbb->CreateString(next_req_time); + schema::ResponseClientListSignBuilder rsp_builder(*(fbb.get())); + rsp_builder.add_retcode(static_cast(retcode)); + rsp_builder.add_reason(rsp_reason); + rsp_builder.add_next_req_time(rsp_next_req_time); + rsp_builder.add_iteration(SizeToInt(iteration)); + auto rsp_push_list_sign = rsp_builder.Finish(); + fbb->Finish(rsp_push_list_sign); + return; +} + +REG_ROUND_KERNEL(pushListSign, PushListSignKernel) +} // namespace kernel +} // namespace server +} // namespace fl +} // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.h new file mode 100644 index 00000000000..54d168e5502 --- /dev/null +++ b/mindspore/ccsrc/fl/server/kernel/round/push_list_sign_kernel.h @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_LIST_SIGN_KERNEL_H +#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_LIST_SIGN_KERNEL_H + +#include +#include +#include +#include "fl/server/common.h" +#include "fl/server/kernel/round/round_kernel.h" +#include "fl/server/kernel/round/round_kernel_factory.h" +#include "fl/armour/cipher/cipher_init.h" +#include "fl/server/executor.h" + +namespace mindspore { +namespace fl { +namespace server { +namespace kernel { +// results of signature verification +enum sigVerifyResult { FAILED, TIMEOUT, PASSED }; + +class PushListSignKernel : public RoundKernel { + public: + PushListSignKernel() = default; + ~PushListSignKernel() override = default; + void InitKernel(size_t required_cnt) override; + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs) override; + bool LaunchForPushListSign(const schema::SendClientListSign *client_list_sign_req, const size_t &iter_num, + const std::shared_ptr &fbb, const std::vector &outputs); + bool Reset() override; + void BuildPushListSignKernelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, + const string &reason, const string &next_req_time, const size_t iteration); + + private: + armour::CipherInit *cipher_init_; + Executor *executor_; + size_t iteration_time_window_; + sigVerifyResult VerifySignature(const schema::SendClientListSign *client_list_sign_req); + bool PushListSign(const size_t cur_iterator, const std::string &next_req_time, + const schema::SendClientListSign *client_list_sign_req, + const std::shared_ptr &fbb, + const std::vector &update_model_clients); +}; +} // namespace kernel +} // namespace server +} // namespace fl +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_PUSH_LIST_SIGN_KERNEL_H diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc index c1ba5dceae2..cbd19069a7f 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/push_metrics_kernel.cc @@ -98,7 +98,7 @@ ResultCode PushMetricsKernel::PushMetrics(const std::shared_ptr &fbb, void PushMetricsKernel::BuildPushMetricsRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode) { MS_ERROR_IF_NULL_WO_RET_VAL(fbb); schema::ResponsePushMetricsBuilder rsp_push_metrics_builder(*(fbb.get())); - rsp_push_metrics_builder.add_retcode(static_cast(retcode)); + rsp_push_metrics_builder.add_retcode(retcode); auto rsp_push_metrics = rsp_push_metrics_builder.Finish(); fbb->Finish(rsp_push_metrics); } diff --git a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc index 4afe9009701..833d91611b3 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/push_weight_kernel.cc @@ -33,6 +33,13 @@ void PushWeightKernel::InitKernel(size_t) { bool PushWeightKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { MS_LOG(INFO) << "Launching PushWeightKernel kernel."; + if (inputs.size() != 1 || outputs.size() != 1) { + std::string reason = "inputs or outputs size is invalid."; + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, reason.c_str(), reason.size()); + return true; + } + void *req_data = inputs[0]->addr; std::shared_ptr fbb = std::make_shared(); if (fbb == nullptr || req_data == nullptr) { @@ -141,7 +148,7 @@ void PushWeightKernel::BuildPushWeightRsp(const std::shared_ptr &fbb, } auto fbs_reason = fbb->CreateString(reason); schema::ResponsePushWeightBuilder rsp_push_weight_builder(*(fbb.get())); - rsp_push_weight_builder.add_retcode(static_cast(retcode)); + rsp_push_weight_builder.add_retcode(SizeToInt(retcode)); rsp_push_weight_builder.add_reason(fbs_reason); rsp_push_weight_builder.add_iteration(SizeToInt(iteration)); auto rsp_push_weight = rsp_push_weight_builder.Finish(); diff --git a/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc index 9e53916546d..1195504c584 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.cc @@ -18,6 +18,8 @@ #include #include #include +#include +#include namespace mindspore { namespace fl { @@ -27,7 +29,6 @@ void ReconstructSecretsKernel::InitKernel(size_t) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); } - auto last_cnt_handler = [&](std::shared_ptr) { if (ps::PSContext::instance()->resetter_round() == ps::ResetterRound::kReconstructSeccrets) { MS_LOG(INFO) << "start FinishIteration"; @@ -44,14 +45,52 @@ void ReconstructSecretsKernel::InitKernel(size_t) { {first_cnt_handler, last_cnt_handler}); } +sigVerifyResult ReconstructSecretsKernel::VerifySignature(const schema::SendReconstructSecret *reconstruct_secret_req) { + std::string fl_id = reconstruct_secret_req->fl_id()->str(); + std::string timestamp = reconstruct_secret_req->timestamp()->str(); + int iteration = reconstruct_secret_req->iteration(); + std::string iter_str = std::to_string(iteration); + auto fbs_signature = reconstruct_secret_req->signature(); + std::vector signature; + if (fbs_signature == nullptr) { + MS_LOG(ERROR) << "signature in reconstruct_secret_req is nullptr"; + return sigVerifyResult::FAILED; + } + signature.assign(fbs_signature->begin(), fbs_signature->end()); + std::map key_attestations; + const fl::PBMetadata &key_attestations_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation); + const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation(); + auto iter = key_attestation_pb.key_attestations().begin(); + for (; iter != key_attestation_pb.key_attestations().end(); ++iter) { + (void)key_attestations.emplace(std::pair(iter->first, iter->second)); + } + if (key_attestations.find(fl_id) == key_attestations.end()) { + MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id; + return sigVerifyResult::FAILED; + } + + std::vector src_data; + (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end()); + (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end()); + mindspore::ps::server::CertVerify certVerify; + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH); + if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) { + return sigVerifyResult::FAILED; + } + if (!certVerify.verifyTimeStamp(fl_id, timestamp)) { + return sigVerifyResult::TIMEOUT; + } + MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success."; + return sigVerifyResult::PASSED; +} + bool ReconstructSecretsKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { bool response = false; size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); - size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); - MS_LOG(INFO) << "Iteration number is " << iter_num << ", ReconstructSecretsKernel total duration is " - << total_duration; - clock_t start_time = clock(); + MS_LOG(INFO) << "Launching ReconstructSecrets Kernel, Iteration number is " << iter_num; if (inputs.size() != 1 || outputs.size() != 1) { MS_LOG(ERROR) << "ReconstructSecretsKernel needs 1 input, but got " << inputs.size(); @@ -73,14 +112,55 @@ bool ReconstructSecretsKernel::Launch(const std::vector &inputs, con DistributedMetadataStore::GetInstance().GetMetadata(kCtxUpdateModelClientList); const UpdateModelClientList &update_model_clients_pb = update_model_clients_pb_out.client_list(); - for (int i = 0; i < update_model_clients_pb.fl_id_size(); ++i) { + for (size_t i = 0; i < IntToSize(update_model_clients_pb.fl_id_size()); ++i) { update_model_clients.push_back(update_model_clients_pb.fl_id(i)); } - + flatbuffers::Verifier verifier(reinterpret_cast(req_data), inputs[0]->size); + if (!verifier.VerifyBuffer()) { + std::string reason = "The schema of SendReconstructSecret is invalid."; + cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num), + std::to_string(CURRENT_TIME_MILLI.count())); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } const schema::SendReconstructSecret *reconstruct_secret_req = flatbuffers::GetRoot(req_data); - std::string fl_id = reconstruct_secret_req->fl_id()->str(); + if (reconstruct_secret_req == nullptr) { + std::string reason = "Building flatbuffers schema failed for SendReconstructSecret."; + cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, SizeToInt(iter_num), + std::to_string(CURRENT_TIME_MILLI.count())); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + // verify signature + if (ps::PSContext::instance()->pki_verify()) { + sigVerifyResult verify_result = VerifySignature(reconstruct_secret_req); + if (verify_result == sigVerifyResult::FAILED) { + std::string reason = "verify signature failed."; + cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, + SizeToInt(iter_num), std::to_string(CURRENT_TIME_MILLI.count())); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + if (verify_result == sigVerifyResult::TIMEOUT) { + std::string reason = "verify signature timestamp failed."; + cipher_reconstruct_.BuildReconstructSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason, SizeToInt(iter_num), + std::to_string(CURRENT_TIME_MILLI.count())); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::PASSED) { + MS_LOG(INFO) << "verify signature passed!"; + } + } + + std::string fl_id = reconstruct_secret_req->fl_id()->str(); if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { MS_LOG(ERROR) << "Current amount for ReconstructSecretsKernel is enough."; if (find(update_model_clients.begin(), update_model_clients.end(), fl_id) != update_model_clients.end()) { @@ -106,11 +186,10 @@ bool ReconstructSecretsKernel::Launch(const std::vector &inputs, con MS_LOG(INFO) << "Current amount for ReconstructSecretsKernel is enough."; } GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - clock_t end_time = clock(); - double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); - MS_LOG(INFO) << "reconstruct_secrets_kernel success time is : " << duration; + + MS_LOG(INFO) << "reconstruct_secrets_kernel success."; if (!response) { - MS_LOG(INFO) << "reconstruct_secrets_kernel response is false."; + MS_LOG(INFO) << "reconstruct_secrets_kernel response not ready."; } return true; } diff --git a/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.h index d64bc9d14ae..f7ba4241cb1 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/reconstruct_secrets_kernel.h @@ -30,6 +30,9 @@ namespace mindspore { namespace fl { namespace server { namespace kernel { +// results of signature verification +enum sigVerifyResult { FAILED, TIMEOUT, PASSED }; + class ReconstructSecretsKernel : public RoundKernel { public: ReconstructSecretsKernel() = default; @@ -43,8 +46,10 @@ class ReconstructSecretsKernel : public RoundKernel { private: std::string name_unmask_; + Executor *executor_; size_t iteration_time_window_{0}; armour::CipherReconStruct cipher_reconstruct_; + sigVerifyResult VerifySignature(const schema::SendReconstructSecret *reconstruct_secret_req); }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc index e8002883aca..042323ab281 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.cc @@ -124,7 +124,7 @@ void RoundKernel::GenerateOutput(const std::vector &outputs, const v outputs[0]->size = len; std::unique_lock lock(heap_data_mtx_); - (void)heap_data_.emplace(outputs[0], std::move(output_data)); + (void)heap_data_.insert(std::make_pair(outputs[0], std::move(output_data))); return; } } // namespace kernel diff --git a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h index 645713ed3f9..c7184a89eda 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/round_kernel.h @@ -26,7 +26,7 @@ #include #include #include -#include "utils/hash_map.h" +#include #include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "fl/server/common.h" @@ -120,7 +120,7 @@ class RoundKernel : virtual public CPUKernel { std::mutex release_mtx_; std::queue heap_data_to_release_; std::mutex heap_data_mtx_; - mindspore::HashMap> heap_data_; + std::unordered_map> heap_data_; }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/round/round_kernel_factory.h b/mindspore/ccsrc/fl/server/kernel/round/round_kernel_factory.h index ac02c8794f3..c3a7f675139 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/round_kernel_factory.h +++ b/mindspore/ccsrc/fl/server/kernel/round/round_kernel_factory.h @@ -20,8 +20,9 @@ #include #include #include -#include "utils/hash_map.h" +#include #include "fl/server/common.h" +#include "fl/server/cert_verify.h" #include "fl/server/kernel/round/round_kernel.h" namespace mindspore { @@ -42,7 +43,7 @@ class RoundKernelFactory { RoundKernelFactory(const RoundKernelFactory &) = delete; RoundKernelFactory &operator=(const RoundKernelFactory &) = delete; - mindspore::HashMap name_to_creator_map_; + std::unordered_map name_to_creator_map_; }; class RoundKernelRegister { diff --git a/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc index 81aa80e8c08..25b954b8ea8 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.cc @@ -17,6 +17,8 @@ #include "fl/server/kernel/round/share_secrets_kernel.h" #include #include +#include +#include namespace mindspore { namespace fl { @@ -26,13 +28,6 @@ void ShareSecretsKernel::InitKernel(size_t) { if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { iteration_time_window_ = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); } - - executor_ = &Executor::GetInstance(); - MS_EXCEPTION_IF_NULL(executor_); - if (!executor_->initialized()) { - MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; - return; - } cipher_share_ = &armour::CipherShares::GetInstance(); } @@ -49,21 +44,58 @@ bool ShareSecretsKernel::CountForShareSecrets(const std::shared_ptr & return true; } +sigVerifyResult ShareSecretsKernel::VerifySignature(const schema::RequestShareSecrets *share_secrets_req) { + std::string fl_id = share_secrets_req->fl_id()->str(); + std::string timestamp = share_secrets_req->timestamp()->str(); + int iteration = share_secrets_req->iteration(); + std::string iter_str = std::to_string(iteration); + auto fbs_signature = share_secrets_req->signature(); + std::vector signature; + if (fbs_signature == nullptr) { + MS_LOG(ERROR) << "signature in share_secrets_req is nullptr"; + return sigVerifyResult::FAILED; + } + signature.assign(fbs_signature->begin(), fbs_signature->end()); + std::map key_attestations; + const fl::PBMetadata &key_attestations_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation); + const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation(); + auto iter = key_attestation_pb.key_attestations().begin(); + for (; iter != key_attestation_pb.key_attestations().end(); ++iter) { + (void)key_attestations.emplace(std::pair(iter->first, iter->second)); + } + if (key_attestations.find(fl_id) == key_attestations.end()) { + MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id; + return sigVerifyResult::FAILED; + } + + std::vector src_data; + (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end()); + (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end()); + mindspore::ps::server::CertVerify certVerify; + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH); + if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) { + return sigVerifyResult::FAILED; + } + if (!certVerify.verifyTimeStamp(fl_id, timestamp)) { + return sigVerifyResult::TIMEOUT; + } + MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success."; + return sigVerifyResult::PASSED; +} + bool ShareSecretsKernel::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { bool response = false; size_t iter_num = LocalMetaStore::GetInstance().curr_iter_num(); - size_t total_duration = LocalMetaStore::GetInstance().value(kCtxTotalTimeoutDuration); - MS_LOG(INFO) << "ITERATION NUMBER IS : " << iter_num << ", Total ShareSecretsKernel allowed Duration Is " - << total_duration; - clock_t start_time = clock(); + MS_LOG(INFO) << "Launching ShareSecretsKernel, ITERATION NUMBER IS : " << iter_num; if (inputs.size() != 1 || outputs.size() != 1) { std::string reason = "inputs or outputs size is invalid."; MS_LOG(ERROR) << reason; return false; } - std::shared_ptr fbb = std::make_shared(); void *req_data = inputs[0]->addr; if (fbb == nullptr || req_data == nullptr) { @@ -71,7 +103,6 @@ bool ShareSecretsKernel::Launch(const std::vector &inputs, const std MS_LOG(ERROR) << reason; return false; } - if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { MS_LOG(ERROR) << "Current amount for ShareSecretsKernel is enough."; cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, @@ -80,7 +111,50 @@ bool ShareSecretsKernel::Launch(const std::vector &inputs, const std GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } + flatbuffers::Verifier verifier(reinterpret_cast(req_data), inputs[0]->size); + if (!verifier.VerifyBuffer()) { + std::string reason = "The schema of RequestShareSecrets is invalid."; + cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } const schema::RequestShareSecrets *share_secrets_req = flatbuffers::GetRoot(req_data); + if (share_secrets_req == nullptr) { + std::string reason = "Building flatbuffers schema failed for RequestShareSecrets."; + cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + // verify signature + if (ps::PSContext::instance()->pki_verify()) { + sigVerifyResult verify_result = VerifySignature(share_secrets_req); + if (verify_result == sigVerifyResult::FAILED) { + std::string reason = "verify signature failed."; + cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_RequestError, reason, + std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::TIMEOUT) { + std::string reason = "verify signature timestamp failed."; + cipher_share_->BuildShareSecretsRsp(fbb, schema::ResponseCode_OutOfTime, reason, + std::to_string(CURRENT_TIME_MILLI.count()), SizeToInt(iter_num)); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::PASSED) { + MS_LOG(INFO) << "verify signature passed!"; + } + } + size_t iter_client = IntToSize(share_secrets_req->iteration()); if (iter_num != iter_client) { MS_LOG(ERROR) << "ShareSecretsKernel iteration invalid. server now iteration is " << iter_num @@ -93,7 +167,7 @@ bool ShareSecretsKernel::Launch(const std::vector &inputs, const std response = cipher_share_->ShareSecrets(SizeToInt(iter_num), share_secrets_req, fbb, std::to_string(CURRENT_TIME_MILLI.count())); if (!response) { - MS_LOG(WARNING) << "update secret shares is failed."; + MS_LOG(ERROR) << "update secret shares is failed."; GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); return true; } @@ -102,9 +176,6 @@ bool ShareSecretsKernel::Launch(const std::vector &inputs, const std return true; } GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); - clock_t end_time = clock(); - double duration = static_cast((end_time - start_time) * 1.0 / CLOCKS_PER_SEC); - MS_LOG(INFO) << "share_secrets_kernel success time is : " << duration; return true; } diff --git a/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.h index e94b4f62b9e..3706e5bc043 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/share_secrets_kernel.h @@ -30,6 +30,9 @@ namespace mindspore { namespace fl { namespace server { namespace kernel { +// results of signature verification +enum sigVerifyResult { FAILED, TIMEOUT, PASSED }; + class ShareSecretsKernel : public RoundKernel { public: ShareSecretsKernel() = default; @@ -43,6 +46,7 @@ class ShareSecretsKernel : public RoundKernel { Executor *executor_; size_t iteration_time_window_; armour::CipherShares *cipher_share_; + sigVerifyResult VerifySignature(const schema::RequestShareSecrets *share_secrets_req); bool CountForShareSecrets(const std::shared_ptr &fbb, const schema::RequestShareSecrets *share_secrets_req, const size_t iter_num); }; diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc index a7d24eb5d41..6abf5e07f6d 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.cc @@ -19,11 +19,11 @@ #include #include #include -#include "fl/server/model_store.h" -#include "fl/server/iteration.h" #ifdef ENABLE_ARMOUR #include "fl/armour/cipher/cipher_init.h" #endif +#include "fl/server/model_store.h" +#include "fl/server/iteration.h" namespace mindspore { namespace fl { @@ -46,6 +46,9 @@ void StartFLJobKernel::InitKernel(size_t) { PBMetadata devices_metas; DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas); + + PBMetadata client_key_attestation; + DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxClientKeyAttestation, client_key_attestation); return; } @@ -93,6 +96,17 @@ bool StartFLJobKernel::Launch(const std::vector &inputs, const std:: return true; } + if (ps::PSContext::instance()->pki_verify()) { + if (!JudgeFLJobCert(fbb, start_fl_job_req)) { + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + if (!StoreKeyAttestation(fbb, start_fl_job_req)) { + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + } + DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req); result_code = ReadyForStartFLJob(fbb, device_meta); if (result_code != ResultCode::kSuccess) { @@ -122,16 +136,87 @@ bool StartFLJobKernel::Launch(const std::vector &inputs, const std:: return true; } +bool StartFLJobKernel::JudgeFLJobCert(const std::shared_ptr &fbb, + const schema::RequestFLJob *start_fl_job_req) { + std::string fl_id = start_fl_job_req->fl_id()->str(); + std::string timestamp = start_fl_job_req->timestamp()->str(); + auto sign_data_vector = start_fl_job_req->sign_data(); + if (sign_data_vector == nullptr || sign_data_vector->size() == 0) { + std::string reason = "sign data is empty."; + BuildStartFLJobRsp( + fbb, schema::ResponseCode_RequestError, reason, false, + std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); + MS_LOG(ERROR) << reason; + return false; + } + unsigned char sign_data[sign_data_vector->size()]; + + for (unsigned int i = 0; i < sign_data_vector->size(); i++) { + sign_data[i] = sign_data_vector->Get(i); + } + + std::string key_attestation = start_fl_job_req->key_attestation()->str(); + std::string equip_cert = start_fl_job_req->equip_cert()->str(); + std::string equip_ca_cert = start_fl_job_req->equip_ca_cert()->str(); + std::string root_first_ca_path = ps::PSContext::instance()->root_first_ca_path(); + std::string root_second_ca_path = ps::PSContext::instance()->root_second_ca_path(); + std::string equip_crl_path = ps::PSContext::instance()->equip_crl_path(); + + mindspore::ps::server::CertVerify certVerify; + bool ret = + certVerify.verifyCertAndSign(fl_id, timestamp, (const unsigned char *)sign_data, key_attestation, equip_cert, + equip_ca_cert, root_first_ca_path, root_second_ca_path, equip_crl_path); + if (!ret) { + std::string reason = "startFLJob sign and certificate verify failed."; + BuildStartFLJobRsp( + fbb, schema::ResponseCode_RequestError, reason, false, + std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); + MS_LOG(ERROR) << reason; + } else { + MS_LOG(INFO) << "JudgeFLJobVerify success." << ret; + } + + return ret; +} + +bool StartFLJobKernel::StoreKeyAttestation(const std::shared_ptr &fbb, + const schema::RequestFLJob *start_fl_job_req) { + // update key attestation + if (start_fl_job_req == nullptr) { + return false; + } + std::string fl_id = start_fl_job_req->fl_id()->str(); + std::string key_attestation = start_fl_job_req->key_attestation()->str(); + + fl::PairKeyAttestation pair_key_attestation_pb; + pair_key_attestation_pb.set_fl_id(fl_id); + pair_key_attestation_pb.set_certificate(key_attestation); + + fl::PBMetadata pb_data; + pb_data.mutable_pair_key_attestation()->MergeFrom(pair_key_attestation_pb); + bool ret = fl::server::DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxClientKeyAttestation, pb_data); + if (!ret) { + std::string reason = "startFLJob: store key attestation failed"; + MS_LOG(ERROR) << reason; + BuildStartFLJobRsp( + fbb, schema::ResponseCode_OutOfTime, reason, false, + std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); + return false; + } + return true; +} + bool StartFLJobKernel::Reset() { MS_LOG(INFO) << "Starting fl job kernel reset!"; StopTimer(); DistributedCountService::GetInstance().ResetCounter(name_); DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas); + DistributedMetadataStore::GetInstance().ResetMetadata(kCtxClientKeyAttestation); return true; } void StartFLJobKernel::OnFirstCountEvent(const std::shared_ptr &) { - iter_next_req_timestamp_ = LongToSize(CURRENT_TIME_MILLI.count()) + iteration_time_window_; + iter_next_req_timestamp_ = LongToUlong(CURRENT_TIME_MILLI.count()) + iteration_time_window_; LocalMetaStore::GetInstance().put_value(kCtxIterationNextRequestTimestamp, iter_next_req_timestamp_); // The first startFLJob request means a new iteration starts running. Iteration::GetInstance().SetIterationRunning(); @@ -241,9 +326,11 @@ void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr &fbb, fl_plan_builder.add_epochs(SizeToInt(ps::PSContext::instance()->client_epoch_num())); fl_plan_builder.add_mini_batch(SizeToInt(ps::PSContext::instance()->client_batch_size())); fl_plan_builder.add_lr(ps::PSContext::instance()->client_learning_rate()); + #ifdef ENABLE_ARMOUR fl_plan_builder.add_cipher(cipher_public_params); #endif + auto fbs_fl_plan = fl_plan_builder.Finish(); std::vector> fbs_feature_maps; diff --git a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h index c6f5cd0ba08..c0a4f4ff7a7 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/start_fl_job_kernel.h @@ -58,6 +58,10 @@ class StartFLJobKernel : public RoundKernel { void StartFLJob(const std::shared_ptr &fbb, const DeviceMeta &device_meta); + bool JudgeFLJobCert(const std::shared_ptr &fbb, const schema::RequestFLJob *start_fl_job_req); + + bool StoreKeyAttestation(const std::shared_ptr &fbb, const schema::RequestFLJob *start_fl_job_req); + // Build response for startFLJob round no matter success or failure. void BuildStartFLJobRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const bool is_selected, const std::string &next_req_time, diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc index 54082ac2bdd..ffa5875a12b 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "fl/server/kernel/round/update_model_kernel.h" namespace mindspore { @@ -85,6 +86,29 @@ bool UpdateModelKernel::Launch(const std::vector &inputs, const std: return true; } + // verify signature + if (ps::PSContext::instance()->pki_verify()) { + sigVerifyResult verify_result = VerifySignature(update_model_req); + if (verify_result == sigVerifyResult::FAILED) { + std::string reason = "verify signature failed."; + BuildUpdateModelRsp(fbb, schema::ResponseCode_RequestError, reason, ""); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + + if (verify_result == sigVerifyResult::TIMEOUT) { + std::string reason = "verify signature timestamp failed."; + BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, ""); + MS_LOG(ERROR) << reason; + GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); + return true; + } + if (verify_result == sigVerifyResult::PASSED) { + MS_LOG(INFO) << "verify signature passed!"; + } + } + result_code = UpdateModel(update_model_req, fbb); if (result_code != ResultCode::kSuccess) { MS_LOG(ERROR) << "Updating model failed."; @@ -144,12 +168,11 @@ ResultCode UpdateModelKernel::UpdateModel(const schema::RequestUpdateModel *upda MS_ERROR_IF_NULL_W_RET_VAL(update_model_req, ResultCode::kSuccessAndReturn); size_t iteration = IntToSize(update_model_req->iteration()); if (iteration != LocalMetaStore::GetInstance().curr_iter_num()) { + auto next_req_time = LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp); std::string reason = "UpdateModel iteration number is invalid:" + std::to_string(iteration) + ", current iteration:" + std::to_string(LocalMetaStore::GetInstance().curr_iter_num()) + - ". Retry later."; - BuildUpdateModelRsp( - fbb, schema::ResponseCode_OutOfTime, reason, - std::to_string(LocalMetaStore::GetInstance().value(kCtxIterationNextRequestTimestamp))); + ". Retry later at time: " + std::to_string(next_req_time); + BuildUpdateModelRsp(fbb, schema::ResponseCode_OutOfTime, reason, std::to_string(next_req_time)); MS_LOG(WARNING) << reason; return ResultCode::kSuccessAndReturn; } @@ -259,6 +282,47 @@ ResultCode UpdateModelKernel::CountForUpdateModel(const std::shared_ptrfl_id()->str(); + std::string timestamp = update_model_req->timestamp()->str(); + int iteration = update_model_req->iteration(); + std::string iter_str = std::to_string(iteration); + auto fbs_signature = update_model_req->signature(); + std::vector signature; + if (fbs_signature == nullptr) { + MS_LOG(ERROR) << "signature in client_list_sign_req is nullptr"; + return sigVerifyResult::FAILED; + } + signature.assign(fbs_signature->begin(), fbs_signature->end()); + std::map key_attestations; + const fl::PBMetadata &key_attestations_pb_out = + fl::server::DistributedMetadataStore::GetInstance().GetMetadata(kCtxClientKeyAttestation); + const fl::KeyAttestation &key_attestation_pb = key_attestations_pb_out.key_attestation(); + auto iter = key_attestation_pb.key_attestations().begin(); + for (; iter != key_attestation_pb.key_attestations().end(); ++iter) { + (void)key_attestations.emplace(std::pair(iter->first, iter->second)); + } + if (key_attestations.find(fl_id) == key_attestations.end()) { + MS_LOG(ERROR) << "can not find key attestation for fl_id: " << fl_id; + return sigVerifyResult::TIMEOUT; + } + + std::vector src_data; + (void)src_data.insert(src_data.end(), timestamp.begin(), timestamp.end()); + (void)src_data.insert(src_data.end(), iter_str.begin(), iter_str.end()); + mindspore::ps::server::CertVerify certVerify; + unsigned char srcDataHash[SHA256_DIGEST_LENGTH]; + certVerify.sha256Hash(src_data.data(), SizeToInt(src_data.size()), srcDataHash, SHA256_DIGEST_LENGTH); + if (!certVerify.verifyRSAKey(key_attestations[fl_id], srcDataHash, signature.data(), SHA256_DIGEST_LENGTH)) { + return sigVerifyResult::FAILED; + } + if (!certVerify.verifyTimeStamp(fl_id, timestamp)) { + return sigVerifyResult::TIMEOUT; + } + MS_LOG(INFO) << "verify signature for fl_id: " << fl_id << " success."; + return sigVerifyResult::PASSED; +} + void UpdateModelKernel::BuildUpdateModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const std::string &next_req_time) { if (fbb == nullptr) { diff --git a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h index 3c9164c1d29..0440608e8d8 100644 --- a/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h +++ b/mindspore/ccsrc/fl/server/kernel/round/update_model_kernel.h @@ -35,10 +35,12 @@ namespace server { namespace kernel { // The initial data size sum of federated learning is 0, which will be accumulated in updateModel round. constexpr uint64_t kInitialDataSizeSum = 0; +// results of signature verification +enum sigVerifyResult { FAILED, TIMEOUT, PASSED }; class UpdateModelKernel : public RoundKernel { public: - UpdateModelKernel() : executor_(nullptr), iteration_time_window_(0) {} + UpdateModelKernel() = default; ~UpdateModelKernel() override = default; void InitKernel(size_t threshold_count) override; @@ -55,6 +57,7 @@ class UpdateModelKernel : public RoundKernel { std::map ParseFeatureMap(const schema::RequestUpdateModel *update_model_req); ResultCode CountForUpdateModel(const std::shared_ptr &fbb, const schema::RequestUpdateModel *update_model_req); + sigVerifyResult VerifySignature(const schema::RequestUpdateModel *update_model_req); void BuildUpdateModelRsp(const std::shared_ptr &fbb, const schema::ResponseCode retcode, const std::string &reason, const std::string &next_req_time); @@ -62,7 +65,7 @@ class UpdateModelKernel : public RoundKernel { Executor *executor_; // The time window of one iteration. - size_t iteration_time_window_; + size_t iteration_time_window_{0}; }; } // namespace kernel } // namespace server diff --git a/mindspore/ccsrc/fl/server/kernel/sgd_kernel.cc b/mindspore/ccsrc/fl/server/kernel/sgd_kernel.cc deleted file mode 100644 index 0b1d13673c0..00000000000 --- a/mindspore/ccsrc/fl/server/kernel/sgd_kernel.cc +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "fl/server/kernel/sgd_kernel.h" - -namespace mindspore { -namespace fl { -namespace server { -namespace kernel { -REG_OPTIMIZER_KERNEL(SGD, - ParamsInfo() - .AddInputNameType(kWeight, kNumberTypeFloat32) - .AddInputNameType(kGradient, kNumberTypeFloat32) - .AddInputNameType(kLearningRate, kNumberTypeFloat32) - .AddInputNameType(kAccumulation, kNumberTypeFloat32) - .AddInputNameType(kMomentum, kNumberTypeFloat32) - .AddInputNameType(kStat, kNumberTypeFloat32), - SGDKernel, float) -} // namespace kernel -} // namespace server -} // namespace fl -} // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/kernel/sgd_kernel.h b/mindspore/ccsrc/fl/server/kernel/sgd_kernel.h deleted file mode 100644 index 233f9c0d8df..00000000000 --- a/mindspore/ccsrc/fl/server/kernel/sgd_kernel.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_ -#define MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_ - -#include -#include -#include -#include "backend/kernel_compiler/cpu/sgd_cpu_kernel.h" -#include "fl/server/kernel/optimizer_kernel.h" -#include "fl/server/kernel/optimizer_kernel_factory.h" - -namespace mindspore { -namespace fl { -namespace server { -namespace kernel { -using mindspore::kernel::SGDCPUKernel; -template -class SGDKernel : public SGDCPUKernel, public OptimizerKernel { - public: - SGDKernel() = default; - ~SGDKernel() override = default; - - void InitKernel(const CNodePtr &cnode) override { - SGDCPUKernel::InitKernel(cnode); - InitServerKernelInputOutputSize(cnode); - GenerateReuseKernelNodeInfo(); - } - - bool Launch(const std::vector &inputs, const std::vector &workspace, - const std::vector &outputs) override { - return SGDCPUKernel::Launch(inputs, workspace, outputs); - } - - void GenerateReuseKernelNodeInfo() override { - MS_LOG(INFO) << "SGD reuse 'weight', 'learning rate', 'accumulation', 'momentum' and 'stat' of the kernel node."; - reuse_kernel_node_inputs_info_.insert(std::make_pair(kWeight, 0)); - reuse_kernel_node_inputs_info_.insert(std::make_pair(kLearningRate, 2)); - reuse_kernel_node_inputs_info_.insert(std::make_pair(kAccumulation, 3)); - reuse_kernel_node_inputs_info_.insert(std::make_pair(kMomentum, 4)); - reuse_kernel_node_inputs_info_.insert(std::make_pair(kStat, 5)); - return; - } -}; -} // namespace kernel -} // namespace server -} // namespace fl -} // namespace mindspore -#endif // MINDSPORE_CCSRC_FL_SERVER_KERNEL_SGD_KERNEL_H_ diff --git a/mindspore/ccsrc/fl/server/local_meta_store.h b/mindspore/ccsrc/fl/server/local_meta_store.h index 98b6b642a31..04c92e6827a 100644 --- a/mindspore/ccsrc/fl/server/local_meta_store.h +++ b/mindspore/ccsrc/fl/server/local_meta_store.h @@ -20,7 +20,7 @@ #include #include #include -#include "utils/hash_map.h" +#include #include "fl/server/common.h" namespace mindspore { @@ -77,10 +77,10 @@ class LocalMetaStore { LocalMetaStore &operator=(const LocalMetaStore &) = delete; // key_to_meta_ stores metadata with key-value format. - mindspore::HashMap key_to_meta_; + std::unordered_map key_to_meta_; // This mutex makes sure that the operations on key_to_meta_ is threadsafe. std::mutex mtx_; - size_t curr_iter_num_; + size_t curr_iter_num_{0}; }; } // namespace server } // namespace fl diff --git a/mindspore/ccsrc/fl/server/model_store.h b/mindspore/ccsrc/fl/server/model_store.h index b8d38829ef8..764226f4cfc 100644 --- a/mindspore/ccsrc/fl/server/model_store.h +++ b/mindspore/ccsrc/fl/server/model_store.h @@ -31,7 +31,7 @@ namespace server { constexpr size_t kInitIterationNum = 0; // The initial iteration number after ModelStore is reset. -constexpr size_t kResetInitIterNum = 1; +constexpr size_t kResetInitialIterNum = 1; // Server framework use ModelStore to store and query models. // ModelStore stores multiple models because worker could get models of the previous iterations. diff --git a/mindspore/ccsrc/fl/server/parameter_aggregator.cc b/mindspore/ccsrc/fl/server/parameter_aggregator.cc index 11db55bb58e..375e1043772 100644 --- a/mindspore/ccsrc/fl/server/parameter_aggregator.cc +++ b/mindspore/ccsrc/fl/server/parameter_aggregator.cc @@ -111,39 +111,6 @@ bool ParameterAggregator::LaunchAggregators() { return true; } -bool ParameterAggregator::LaunchOptimizers() { - for (auto &optimizer_with_params : optimizer_kernel_parameters_) { - KernelParams ¶ms = optimizer_with_params.second; - std::shared_ptr optimizer_kernel = optimizer_with_params.first; - MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false); - bool ret = optimizer_kernel->Launch(params.inputs, params.workspace, params.outputs); - if (!ret) { - MS_LOG(ERROR) << "Launching optimizer kernel " << typeid(optimizer_kernel.get()).name() << " failed."; - continue; - } - } - // As long as all the optimizer kernels are launched, consider optimizing for this ParameterAggregator as done. - optimizing_done_ = true; - return true; -} - -AddressPtr ParameterAggregator::Pull() { - if (memory_register_ == nullptr) { - MS_LOG(ERROR) - << "The memory register of ParameterAggregator is nullptr. Please initialize ParameterAggregator first."; - return nullptr; - } - - current_pull_count_++; - if (current_pull_count_ == required_pull_count_) { - pulling_done_ = true; - } - MS_LOG(DEBUG) << "The " << current_pull_count_ << " time of Pull. Pulling done status: " << pulling_done_; - - std::map &name_to_addr = memory_register_->addresses(); - return name_to_addr["weight"]; -} - AddressPtr ParameterAggregator::GetWeight() { if (memory_register_ == nullptr) { MS_LOG(ERROR) @@ -193,8 +160,8 @@ bool ParameterAggregator::requires_aggr() const { return requires_aggr_; } bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); - if (!JudgeRequiresAggr(cnode)) { - MS_LOG(WARNING) << "Aggregation for weight for kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required."; + if (!JudgeRequiredAggr(cnode)) { + MS_LOG(WARNING) << "Aggregation for weight of kernel " << AnfAlgo::GetCNodeName(cnode) << " is not required."; } std::vector aggr_kernel_names = SelectAggregationAlgorithm(cnode); @@ -223,32 +190,12 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) { return true; } -bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { +bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &) { if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL || ps::PSContext::instance()->server_mode() == ps::kServerModeHybrid) { MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel."; return true; } - MS_EXCEPTION_IF_NULL(cnode); - const std::string &name = AnfAlgo::GetCNodeName(cnode); - auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode); - if (optimizer_kernel == nullptr) { - MS_LOG(EXCEPTION) << "Failed to create optimizer kernel for " << name; - return false; - } - - optimizer_kernel->InitKernel(cnode); - - const ReuseKernelNodeInfo &reuse_kernel_node_inputs_info = optimizer_kernel->reuse_kernel_node_inputs_info(); - if (!AssignMemory(optimizer_kernel, cnode, reuse_kernel_node_inputs_info, memory_register_)) { - MS_LOG(EXCEPTION) << "Assigning memory for kernel " << name << " failed."; - return false; - } - - if (!GenerateOptimizerKernelParams(optimizer_kernel, memory_register_)) { - MS_LOG(ERROR) << "Generating optimizer kernel parameters failed."; - return false; - } return true; } @@ -323,29 +270,6 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr< return true; } -bool ParameterAggregator::GenerateOptimizerKernelParams( - const std::shared_ptr &optimizer_kernel, - const std::shared_ptr &memory_register) { - MS_ERROR_IF_NULL_W_RET_VAL(optimizer_kernel, false); - MS_ERROR_IF_NULL_W_RET_VAL(memory_register, false); - KernelParams optimizer_params = {}; - - const std::vector &input_names = optimizer_kernel->input_names(); - (void)std::transform(input_names.begin(), input_names.end(), std::back_inserter(optimizer_params.inputs), - [&](const std::string &name) { return memory_register->addresses()[name]; }); - - const std::vector &workspace_names = optimizer_kernel->workspace_names(); - (void)std::transform(workspace_names.begin(), workspace_names.end(), std::back_inserter(optimizer_params.workspace), - [&](const std::string &name) { return memory_register->addresses()[name]; }); - - const std::vector &output_names = optimizer_kernel->output_names(); - (void)std::transform(output_names.begin(), output_names.end(), std::back_inserter(optimizer_params.outputs), - [&](const std::string &name) { return memory_register->addresses()[name]; }); - - optimizer_kernel_parameters_.push_back(std::make_pair(optimizer_kernel, optimizer_params)); - return true; -} - std::vector ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &) { std::vector aggregation_algorithm = {}; if (ps::PSContext::instance()->server_mode() == ps::kServerModeFL || @@ -362,7 +286,7 @@ std::vector ParameterAggregator::SelectAggregationAlgorithm(const C return aggregation_algorithm; } -bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) { +bool ParameterAggregator::JudgeRequiredAggr(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); std::string cnode_name = AnfAlgo::GetCNodeName(cnode); if (kNameToIdxMap.count(cnode_name) == 0 || kNameToIdxMap.at(cnode_name).count("inputs") == 0 || @@ -375,7 +299,7 @@ bool ParameterAggregator::JudgeRequiresAggr(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(weight_node); if (!weight_node->isa()) { - MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter node."; + MS_LOG(EXCEPTION) << weight_node->fullname_with_scope() << " is not a parameter."; return false; } auto param_info = weight_node->cast()->param_info(); diff --git a/mindspore/ccsrc/fl/server/parameter_aggregator.h b/mindspore/ccsrc/fl/server/parameter_aggregator.h index 4769b111e78..a67cdf2c9af 100644 --- a/mindspore/ccsrc/fl/server/parameter_aggregator.h +++ b/mindspore/ccsrc/fl/server/parameter_aggregator.h @@ -77,11 +77,6 @@ class ParameterAggregator { // Launch aggregators/optimizers of this ParameterAggregator in order. bool LaunchAggregators(); - bool LaunchOptimizers(); - - // The implementation for primitive Pull in parameter server training mode. - // Every call of this method will increase the count for pull by 1. - AddressPtr Pull(); // Different from the method Pull, this method simply returns the weight of this ParameterAggregator without causing // any change of status. @@ -98,7 +93,6 @@ class ParameterAggregator { bool IsOptimizingDone() const; bool IsPullingDone() const; - // Return whether this parameter requires aggragation. bool requires_aggr() const; private: @@ -119,15 +113,13 @@ class ParameterAggregator { // memory_register. bool GenerateAggregationKernelParams(const std::shared_ptr &aggr_kernel, const std::shared_ptr &memory_register); - bool GenerateOptimizerKernelParams(const std::shared_ptr &optim_kernel, - const std::shared_ptr &memory_register); // The selection of the aggregation algorithm depends on multiple factors. For example, server mode, user // configuration, etc. std::vector SelectAggregationAlgorithm(const CNodePtr &cnode); // Judge whether the parameter needs to be aggregated. - bool JudgeRequiresAggr(const CNodePtr &cnode); + bool JudgeRequiredAggr(const CNodePtr &cnode); ServerMode server_mode_; size_t required_push_count_; diff --git a/mindspore/ccsrc/fl/server/round.cc b/mindspore/ccsrc/fl/server/round.cc index 75652ddd725..261c62f7999 100644 --- a/mindspore/ccsrc/fl/server/round.cc +++ b/mindspore/ccsrc/fl/server/round.cc @@ -38,17 +38,16 @@ void Round::Initialize(const std::shared_ptr &commun const FinishIterCb &finish_iteration_cb) { MS_EXCEPTION_IF_NULL(communicator); communicator_ = communicator; - - // Register callback for round kernel. + MS_LOG(INFO) << "Round " << name_ << " start initialize."; communicator_->RegisterMsgCallBack(name_, [&](std::shared_ptr message) { MS_ERROR_IF_NULL_WO_RET_VAL(message); LaunchRoundKernel(message); }); // Callback when the iteration is finished. - finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void { + finish_iteration_cb_ = [this, finish_iteration_cb](bool, const std::string &) -> void { std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration."; - finish_iteration_cb(is_iteration_valid, reason); + finish_iteration_cb(true, reason); }; // Callback for finalizing the server. This can only be called once. @@ -62,9 +61,9 @@ void Round::Initialize(const std::shared_ptr &commun MS_EXCEPTION_IF_NULL(iter_timer_); // 1.Set the timeout callback for the timer. - iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void { + iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool, const std::string &) -> void { std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration."; - timeout_cb(is_iteration_valid, reason); + timeout_cb(false, reason); }); // 2.Stopping timer callback which will be set to the round kernel. @@ -139,7 +138,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr &m return; } - ++Iteration::GetInstance().running_round_num_; + (void)(Iteration::GetInstance().running_round_num_++); AddressPtr input = std::make_shared
(); AddressPtr output = std::make_shared
(); MS_ERROR_IF_NULL_WO_RET_VAL(input); @@ -167,7 +166,7 @@ void Round::LaunchRoundKernel(const std::shared_ptr &m reason = "Launching round kernel of round " + name_ + " failed."; Iteration::GetInstance().NotifyNext(false, reason); } - --Iteration::GetInstance().running_round_num_; + (void)(Iteration::GetInstance().running_round_num_--); return; } diff --git a/mindspore/ccsrc/fl/server/server.cc b/mindspore/ccsrc/fl/server/server.cc index 754918933f6..783d40e4b6f 100644 --- a/mindspore/ccsrc/fl/server/server.cc +++ b/mindspore/ccsrc/fl/server/server.cc @@ -32,7 +32,7 @@ namespace mindspore { namespace fl { namespace server { -// The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster managers like K8S. +// The handler to capture the signal of SIGTERM. Normally this signal is triggered by cloud cluster manager like K8S. std::shared_ptr g_communicator_with_server = nullptr; std::vector> g_communicators_with_worker = {}; void SignalHandler(int signal) { @@ -45,7 +45,6 @@ void SignalHandler(int signal) { MS_ERROR_IF_NULL_WO_RET_VAL(g_communicator_with_server); (void)g_communicator_with_server->Stop(); - return; } void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const std::vector &rounds_config, @@ -68,24 +67,10 @@ void Server::Initialize(bool use_tcp, bool use_http, uint16_t http_port, const s return; } -// Each step of the server pipeline may have dependency on other steps, which includes: - -// InitServerContext must be the first step to set contexts for later steps. - -// Server Running relies on URL or Message Type Register: -// StartCommunicator---->InitIteration - -// Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion: -// RegisterRoundKernel---->StartCommunicator - -// Kernel Initialization relies on Executor Initialization: -// RegisterRoundKernel---->InitExecutor - -// Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization: -// InitCipher---->InitExecutor void Server::Run() { std::unique_lock lock(scaling_mtx_); InitServerContext(); + InitPkiCertificate(); InitCluster(); InitIteration(); RegisterCommCallbacks(); @@ -98,6 +83,7 @@ void Server::Run() { } RegisterRoundKernel(); InitMetrics(); + Recover(); MS_LOG(INFO) << "Server started successfully."; safemode_ = false; lock.unlock(); @@ -111,9 +97,27 @@ void Server::Run() { MS_EXCEPTION_IF_NULL(communicator_with_server_); communicator_with_server_->Join(); MsException::Instance().CheckException(); + func_graph_ = nullptr; return; } +void Server::InitPkiCertificate() { + if (ps::PSContext::instance()->pki_verify()) { + root_first_ca_path_ = ps::PSContext::instance()->root_first_ca_path(); + root_second_ca_path_ = ps::PSContext::instance()->root_second_ca_path(); + equip_crl_path_ = ps::PSContext::instance()->equip_crl_path(); + replay_attack_time_diff_ = ps::PSContext::instance()->replay_attack_time_diff(); + + bool ret = mindspore::ps::server::CertVerify::initRootCertAndCRL(root_first_ca_path_, root_second_ca_path_, + equip_crl_path_, replay_attack_time_diff_); + if (!ret) { + MS_LOG(EXCEPTION) << "init root cert and crl failed."; + return; + } + return; + } +} + void Server::SwitchToSafeMode() { MS_LOG(INFO) << "Server switch to safemode."; safemode_ = true; @@ -138,11 +142,6 @@ void Server::InitServerContext() { scheduler_port_ = ps::PSContext::instance()->scheduler_port(); worker_num_ = ps::PSContext::instance()->initial_worker_num(); server_num_ = ps::PSContext::instance()->initial_server_num(); - std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); - if (encrypt_type == ps::kPWEncryptType && server_num_ > 1) { - MS_LOG(EXCEPTION) << "Only single server is supported for PW_ENCRYPT now, but got server_num is:." << server_num_; - return; - } return; } @@ -218,6 +217,8 @@ void Server::InitIteration() { cipher_share_secrets_cnt_ = cipher_config_.share_secrets_threshold; cipher_get_secrets_cnt_ = cipher_config_.get_secrets_threshold; cipher_get_clientlist_cnt_ = cipher_config_.client_list_threshold; + cipher_push_list_sign_cnt_ = cipher_config_.push_list_sign_threshold; + cipher_get_list_sign_cnt_ = cipher_config_.get_list_sign_threshold; cipher_reconstruct_secrets_up_cnt_ = cipher_config_.reconstruct_secrets_threshold; cipher_reconstruct_secrets_down_cnt_ = cipher_config_.reconstruct_secrets_threshold - 1; cipher_time_window_ = cipher_config_.cipher_time_window; @@ -228,6 +229,8 @@ void Server::InitIteration() { << " cipher_share_secrets_cnt_: " << cipher_share_secrets_cnt_; MS_LOG(INFO) << " cipher_get_secrets_cnt_: " << cipher_get_secrets_cnt_ << " cipher_get_clientlist_cnt_: " << cipher_get_clientlist_cnt_ + << " cipher_push_list_sign_cnt_: " << cipher_push_list_sign_cnt_ + << " cipher_get_list_sign_cnt_: " << cipher_get_list_sign_cnt_ << " cipher_reconstruct_secrets_up_cnt_: " << cipher_reconstruct_secrets_up_cnt_ << " cipher_reconstruct_secrets_down_cnt_: " << cipher_reconstruct_secrets_down_cnt_ << " cipher_time_window_: " << cipher_time_window_; @@ -245,6 +248,7 @@ void Server::InitIteration() { void Server::InitCipher() { #ifdef ENABLE_ARMOUR cipher_init_ = &armour::CipherInit::GetInstance(); + int cipher_t = SizeToInt(cipher_reconstruct_secrets_down_cnt_); unsigned char cipher_p[SECRET_MAX_LEN] = {0}; const int cipher_g = 1; @@ -258,8 +262,7 @@ void Server::InitCipher() { param.t = cipher_t; int ret = memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, sizeof(cipher_p)); if (ret != 0) { - MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; - return; + MS_LOG(EXCEPTION) << "Memcpy_s error, errorno" << ret; } param.dp_delta = dp_delta; param.dp_eps = dp_eps; @@ -281,10 +284,9 @@ void Server::InitCipher() { if (prim != NULL) { BN_clear_free(prim); } - - (void)cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_, - cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_, - cipher_reconstruct_secrets_up_cnt_); + cipher_init_->Init(param, 0, cipher_exchange_keys_cnt_, cipher_get_keys_cnt_, cipher_share_secrets_cnt_, + cipher_get_secrets_cnt_, cipher_get_clientlist_cnt_, cipher_push_list_sign_cnt_, + cipher_get_list_sign_cnt_, cipher_reconstruct_secrets_up_cnt_); #endif } @@ -366,12 +368,14 @@ void Server::RegisterMessageCallback(const std::shared_ptrRegisterMsgCallBack("queryInstance", std::bind(&Server::HandleQueryInstanceRequest, this, std::placeholders::_1)); + communicator->RegisterMsgCallBack("syncAfterRecover", + std::bind(&Server::HandleSyncAfterRecoveryRequest, this, std::placeholders::_1)); } void Server::InitExecutor() { MS_EXCEPTION_IF_NULL(func_graph_); if (executor_threshold_ == 0) { - MS_LOG(EXCEPTION) << "The executor's threshold should be greater than 0."; + MS_LOG(EXCEPTION) << "The executor's threshold should greater than 0."; return; } // The train engine instance is used in both push-type and pull-type kernels, @@ -430,7 +434,6 @@ void Server::StartCommunicator() { MS_LOG(INFO) << "Start communicator with server."; if (!communicator_with_server_->Start()) { MS_LOG(EXCEPTION) << "Starting communicator with server failed."; - return; } DistributedMetadataStore::GetInstance().Initialize(server_node_); CollectiveOpsImpl::GetInstance().Initialize(server_node_); @@ -447,6 +450,32 @@ void Server::StartCommunicator() { }); } +void Server::Recover() { + server_recovery_ = std::make_shared(); + MS_EXCEPTION_IF_NULL(server_recovery_); + + // Try to recovery from persistent storage. + if (!server_recovery_->Initialize(ps::PSContext::instance()->config_file_path())) { + MS_LOG(WARNING) << "Initializing server recovery failed. Do not recover for this server."; + return; + } + + if (server_recovery_->Recover()) { + // If this server recovers, need to notify cluster to reach consistency. + auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); + MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); + MS_LOG(INFO) << "Synchronize with leader server after recovery."; + if (!server_recovery_->SyncAfterRecovery(tcp_comm, server_node_->rank_id())) { + MS_LOG(EXCEPTION) << "Failed to reach consistency of the cluster after recovery."; + return; + } + } + + // Set the recovery handler to Iteration. + MS_EXCEPTION_IF_NULL(iteration_); + iteration_->set_recovery_handler(server_recovery_); +} + void Server::ProcessBeforeScalingOut() { MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); iteration_->ScalingBarrier(); @@ -510,7 +539,6 @@ void Server::ProcessAfterScalingIn() { } void Server::HandleEnableServerRequest(const std::shared_ptr &message) { - MS_ERROR_IF_NULL_WO_RET_VAL(message); MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); @@ -528,7 +556,6 @@ void Server::HandleEnableServerRequest(const std::shared_ptr &message) { - MS_ERROR_IF_NULL_WO_RET_VAL(message); MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); @@ -552,7 +579,6 @@ void Server::HandleNewInstanceRequest(const std::shared_ptr(communicator_with_server_); MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); - MS_ERROR_IF_NULL_WO_RET_VAL(message->data()); std::string hyper_params_str(static_cast(message->data()), message->len()); nlohmann::json new_instance_json; nlohmann::json response; @@ -595,6 +621,31 @@ void Server::HandleQueryInstanceRequest(const std::shared_ptr &message) { + MS_ERROR_IF_NULL_WO_RET_VAL(message); + MS_ERROR_IF_NULL_WO_RET_VAL(iteration_); + MS_ERROR_IF_NULL_WO_RET_VAL(communicator_with_server_); + auto tcp_comm = std::dynamic_pointer_cast(communicator_with_server_); + MS_ERROR_IF_NULL_WO_RET_VAL(tcp_comm); + + MS_LOG(INFO) << "Receive SyncAfterRecover request from other server."; + std::string response = "success"; + if (!tcp_comm->SendResponse(response.c_str(), response.size(), message)) { + MS_LOG(ERROR) << "Sending response of SyncAfterRecoverRequest failed."; + return; + } + + if (!safemode_.load()) { + MS_LOG(INFO) << "Need to synchronize for other server's recovery"; + SyncAfterRecover sync_after_recovery_req; + (void)sync_after_recovery_req.ParseFromArray(message->data(), SizeToInt(message->len())); + if (!iteration_->SyncAfterRecovery(sync_after_recovery_req.current_iter_num())) { + MS_LOG(ERROR) << "Sync after recovery failed."; + return; + } + } +} } // namespace server } // namespace fl } // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/server.h b/mindspore/ccsrc/fl/server/server.h index 4d5109e6348..26d77aebdd7 100644 --- a/mindspore/ccsrc/fl/server/server.h +++ b/mindspore/ccsrc/fl/server/server.h @@ -24,19 +24,19 @@ #include "ps/core/communicator/tcp_communicator.h" #include "ps/core/communicator/task_executor.h" #include "ps/core/file_configuration.h" -#include "fl/server/common.h" -#include "fl/server/executor.h" -#include "fl/server/iteration.h" #ifdef ENABLE_ARMOUR #include "fl/armour/cipher/cipher_init.h" #endif +#include "fl/server/common.h" +#include "fl/server/executor.h" +#include "fl/server/iteration.h" namespace mindspore { namespace fl { namespace server { // The sleeping time of the server thread before the networking is completed. constexpr uint32_t kServerSleepTimeForNetworking = 1000; - +constexpr uint64_t kDefaultReplayAttackTimeDiff = 60000; // Class Server is the entrance of MindSpore's parameter server training mode and federated learning. class Server { public: @@ -51,6 +51,21 @@ class Server { // According to the current MindSpore framework, method Run is a step of the server pipeline. This method will be // blocked until the server is finalized. // func_graph is the frontend graph which will be parse in server's exector and aggregator. + + // Each step of the server pipeline may have dependency on other steps, which includes: + // InitServerContext must be the first step to set contexts for later steps. + + // Server Running relies on URL or Message Type Register: + // StartCommunicator---->InitIteration + + // Metadata Register relies on Hash Ring of Servers which relies on Network Building Completion: + // RegisterRoundKernel---->StartCommunicator + + // Kernel Initialization relies on Executor Initialization: + // RegisterRoundKernel---->InitExecutor + + // Getting Model Size relies on ModelStorage Initialization which relies on Executor Initialization: + // InitCipher---->InitExecutor void Run(); void SwitchToSafeMode(); @@ -74,17 +89,25 @@ class Server { communicators_with_worker_({}), iteration_(nullptr), safemode_(true), + server_recovery_(nullptr), scheduler_ip_(""), scheduler_port_(0), server_num_(0), worker_num_(0), fl_server_port_(0), + pki_verify_(false), + root_first_ca_path_(""), + root_second_ca_path_(""), + equip_crl_path_(""), + replay_attack_time_diff_(kDefaultReplayAttackTimeDiff), cipher_initial_client_cnt_(0), cipher_exchange_keys_cnt_(0), cipher_get_keys_cnt_(0), cipher_share_secrets_cnt_(0), cipher_get_secrets_cnt_(0), cipher_get_clientlist_cnt_(0), + cipher_push_list_sign_cnt_(0), + cipher_get_list_sign_cnt_(0), cipher_reconstruct_secrets_up_cnt_(0), cipher_reconstruct_secrets_down_cnt_(0), cipher_time_window_(0) {} @@ -95,9 +118,6 @@ class Server { // Load variables which is set by ps_context. void InitServerContext(); - // Try to recover server config from persistent storage. - void Recovery(); - // Initialize the server cluster, server node and communicators. void InitCluster(); bool InitCommunicatorWithServer(); @@ -130,6 +150,12 @@ class Server { // The communicators should be started after all initializations are completed. void StartCommunicator(); + // Try to recover server config from persistent storage. + void Recover(); + + // load pki huks cbg root certificate and crl + void InitPkiCertificate(); + // The barriers before scaling operations. void ProcessBeforeScalingOut(); void ProcessBeforeScalingIn(); @@ -148,6 +174,9 @@ class Server { // Query current instance information. void HandleQueryInstanceRequest(const std::shared_ptr &message); + // Synchronize after recovery is completed to ensure consistency. + void HandleSyncAfterRecoveryRequest(const std::shared_ptr &message); + // The server node is initialized in Server. std::shared_ptr server_node_; @@ -179,7 +208,7 @@ class Server { // communicators. std::vector> communicators_with_worker_; - // Mutex for scaling operations. We must wait server's initialization done before handle scaling events. + // Mutex for scaling operations. std::mutex scaling_mtx_; // Iteration consists of multiple kinds of rounds. @@ -189,21 +218,33 @@ class Server { // If true, the server is not available to workers and clients. std::atomic_bool safemode_; + // The recovery object for server. + std::shared_ptr server_recovery_; + // Variables set by ps context. #ifdef ENABLE_ARMOUR - armour::CipherInit *cipher_init_{nullptr}; + armour::CipherInit *cipher_init_; #endif std::string scheduler_ip_; uint16_t scheduler_port_; uint32_t server_num_; uint32_t worker_num_; uint16_t fl_server_port_; + bool pki_verify_; + + std::string root_first_ca_path_; + std::string root_second_ca_path_; + std::string equip_crl_path_; + uint64_t replay_attack_time_diff_; + size_t cipher_initial_client_cnt_; size_t cipher_exchange_keys_cnt_; size_t cipher_get_keys_cnt_; size_t cipher_share_secrets_cnt_; size_t cipher_get_secrets_cnt_; size_t cipher_get_clientlist_cnt_; + size_t cipher_push_list_sign_cnt_; + size_t cipher_get_list_sign_cnt_; size_t cipher_reconstruct_secrets_up_cnt_; size_t cipher_reconstruct_secrets_down_cnt_; uint64_t cipher_time_window_; diff --git a/mindspore/ccsrc/fl/server/server_recovery.cc b/mindspore/ccsrc/fl/server/server_recovery.cc new file mode 100644 index 00000000000..0c89d30749d --- /dev/null +++ b/mindspore/ccsrc/fl/server/server_recovery.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fl/server/server_recovery.h" +#include "fl/server/local_meta_store.h" +#include "debug/common.h" + +namespace mindspore { +namespace fl { +namespace server { +bool ServerRecovery::Initialize(const std::string &config_file) { + config_ = std::make_unique(config_file); + MS_EXCEPTION_IF_NULL(config_); + if (!config_->Initialize()) { + MS_LOG(EXCEPTION) << "Initializing for server recovery failed. Config file path " << config_file + << " may be invalid or not exist."; + return false; + } + + // Read the server recovery file path. + if (!config_->Exists(kServerRecovery)) { + MS_LOG(WARNING) << "Server recovery config is not set. This node doesn't support recovery."; + return true; + } else { + std::string value = config_->Get(kServerRecovery, ""); + nlohmann::json value_json; + try { + value_json = nlohmann::json::parse(value); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "The data is not in json format."; + return false; + } + + // Parse the storage type. + uint32_t storage_type = JsonGetKeyWithException(value_json, ps::kStoreType); + if (std::to_string(storage_type) != ps::kFileStorage) { + MS_LOG(EXCEPTION) << "Storage type " << storage_type << " is not supported."; + return false; + } + + // Parse storage file path. + server_recovery_file_path_ = JsonGetKeyWithException(value_json, ps::kStoreFilePath); + MS_LOG(INFO) << "Server recovery file path is " << server_recovery_file_path_; + } + return true; +} + +bool ServerRecovery::Recover() { + server_recovery_file_.open(server_recovery_file_path_, std::ios::in); + if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) { + MS_LOG(WARNING) << "Can't open server recovery file " << server_recovery_file_path_; + return false; + } + + nlohmann::json server_recovery_json; + try { + server_recovery_json = nlohmann::json::parse(server_recovery_file_); + } catch (const std::exception &e) { + MS_LOG(EXCEPTION) << "The server recovery file is not in json format."; + return false; + } + uint64_t current_iter = JsonGetKeyWithException(server_recovery_json, kCurrentIteration); + LocalMetaStore::GetInstance().set_curr_iter_num(current_iter); + MS_LOG(INFO) << "Recover from persistent storage: current iteration number is " << current_iter; + server_recovery_file_.close(); + return true; +} + +bool ServerRecovery::Save(uint64_t current_iter) { + server_recovery_file_.open(server_recovery_file_path_, std::ios::out | std::ios::ate); + if (!server_recovery_file_.good() || !server_recovery_file_.is_open()) { + MS_LOG(WARNING) << "Can't save data to recovery file " << server_recovery_file_path_ + << ". This file path is invalid or does not exit."; + return false; + } + + nlohmann::json server_metadata_json; + server_metadata_json[kCurrentIteration] = current_iter; + server_recovery_file_ << server_metadata_json; + server_recovery_file_.close(); + return true; +} + +bool ServerRecovery::SyncAfterRecovery(const std::shared_ptr &communicator, + uint32_t rank_id) { + // If this server is follower server, notify leader server that this server has recovered. + if (rank_id != kLeaderServerRank) { + MS_ERROR_IF_NULL_W_RET_VAL(communicator, false); + SyncAfterRecover sync_after_recover_req; + sync_after_recover_req.set_current_iter_num(LocalMetaStore::GetInstance().curr_iter_num()); + if (!communicator->SendPbRequest(sync_after_recover_req, kLeaderServerRank, + ps::core::TcpUserCommand::kSyncAfterRecover)) { + MS_LOG(ERROR) << "Sending sync after recovery message to leader server failed."; + return false; + } + } + return true; +} +} // namespace server +} // namespace fl +} // namespace mindspore diff --git a/mindspore/ccsrc/fl/server/server_recovery.h b/mindspore/ccsrc/fl/server/server_recovery.h new file mode 100644 index 00000000000..5e41aa41bb1 --- /dev/null +++ b/mindspore/ccsrc/fl/server/server_recovery.h @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FL_SERVER_SERVER_RECOVERY_H_ +#define MINDSPORE_CCSRC_FL_SERVER_SERVER_RECOVERY_H_ + +#include +#include +#include +#include +#include +#include "ps/core/recovery_base.h" +#include "ps/core/file_configuration.h" +#include "ps/core/communicator/tcp_communicator.h" +#include "ps/ps_context.h" + +namespace mindspore { +namespace fl { +namespace server { +constexpr auto kServerRecovery = "server_recovery"; + +// The class helps server node to do recovery operation. +// Different from the recovery process in ps/core/node_recovery.*, this class focus on recovery of the server data. For +// example, current iteration number, learning rate, etc. +class ServerRecovery : public ps::core::RecoveryBase { + public: + ServerRecovery() : config_(nullptr), server_recovery_file_path_("") {} + ~ServerRecovery() override = default; + + bool Initialize(const std::string &config_file) override; + bool Recover() override; + + // Save server's metadata to persistent storage. + bool Save(uint64_t current_iter); + + // If this server recovers, need to notify cluster to reach consistency. + bool SyncAfterRecovery(const std::shared_ptr &communicator, uint32_t rank_id); + + private: + // This is the main config file set by ps context. + std::unique_ptr config_; + + // The server recovery file path. + std::string server_recovery_file_path_; + + // The server recovery file object. + std::fstream server_recovery_file_; +}; +} // namespace server +} // namespace fl +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FL_SERVER_SERVER_RECOVERY_H_ diff --git a/mindspore/ccsrc/fl/worker/fl_worker.cc b/mindspore/ccsrc/fl/worker/fl_worker.cc index ae385b23251..4ae6d8479b4 100644 --- a/mindspore/ccsrc/fl/worker/fl_worker.cc +++ b/mindspore/ccsrc/fl/worker/fl_worker.cc @@ -220,9 +220,9 @@ void FLWorker::InitializeFollowerScaler() { std::bind(&FLWorker::ProcessAfterScalingOut, this)); worker_node_->RegisterFollowerScalerHandlerAfterScaleIn("WorkerPipeline", std::bind(&FLWorker::ProcessAfterScalingIn, this)); - worker_node_->RegisterCustomEventCallback(static_cast(ps::CustomEvent::kIterationRunning), + worker_node_->RegisterCustomEventCallback(static_cast(ps::UserDefineEvent::kIterationRunning), std::bind(&FLWorker::HandleIterationRunningEvent, this)); - worker_node_->RegisterCustomEventCallback(static_cast(ps::CustomEvent::kIterationCompleted), + worker_node_->RegisterCustomEventCallback(static_cast(ps::UserDefineEvent::kIterationCompleted), std::bind(&FLWorker::HandleIterationCompletedEvent, this)); } diff --git a/mindspore/ccsrc/fl/worker/fl_worker.h b/mindspore/ccsrc/fl/worker/fl_worker.h index 912b0b97b59..db9f3e2da7b 100644 --- a/mindspore/ccsrc/fl/worker/fl_worker.h +++ b/mindspore/ccsrc/fl/worker/fl_worker.h @@ -51,7 +51,7 @@ constexpr uint32_t kWorkerSleepTimeForNetworking = 1000; // The time duration between retrying when server is in safemode. constexpr uint32_t kWorkerRetryDurationForSafeMode = 500; -// The rank of the leader server. +// The leader server rank. constexpr uint32_t kLeaderServerRank = 0; // The timeout for worker sending message to server in case of network jitter. diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index bcdd89b04dd..3e118c846c6 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -879,6 +879,10 @@ bool StartServerAction(const ResourcePtr &res) { std::max(static_cast(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold); size_t client_list_threshold = std::max(static_cast(std::ceil(update_model_threshold * share_secrets_ratio)), reconstruct_secrets_threshold); + size_t push_list_sign_threshold = std::max( + static_cast(std::ceil(client_list_threshold * share_secrets_ratio)), reconstruct_secrets_threshold); + size_t get_list_sign_threshold = std::max( + static_cast(std::ceil(push_list_sign_threshold * share_secrets_ratio)), reconstruct_secrets_threshold); #ifdef ENABLE_ARMOUR std::string encrypt_type = ps::PSContext::instance()->encrypt_type(); if (encrypt_type == ps::kPWEncryptType) { @@ -889,6 +893,10 @@ bool StartServerAction(const ResourcePtr &res) { rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold}); rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold}); rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold}); + if (ps::PSContext::instance()->pki_verify()) { + rounds_config.push_back({"pushListSign", true, cipher_time_window, true, push_list_sign_threshold}); + rounds_config.push_back({"getListSign", true, cipher_time_window, true, get_list_sign_threshold}); + } } if (encrypt_type == ps::kStablePWEncryptType) { MS_LOG(INFO) << "Add stable secure aggregation rounds."; @@ -897,8 +905,9 @@ bool StartServerAction(const ResourcePtr &res) { } #endif fl::server::CipherConfig cipher_config = { - share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold, - share_secrets_threshold, get_secrets_threshold, client_list_threshold, reconstruct_secrets_threshold}; + share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold, + share_secrets_threshold, get_secrets_threshold, client_list_threshold, push_list_sign_threshold, + get_list_sign_threshold, reconstruct_secrets_threshold}; size_t executor_threshold = 0; if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index e71aa4b76d7..086f1470450 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -403,6 +403,7 @@ PYBIND11_MODULE(_c_expression, m) { "Set threshold count ratio for share secrets round.") .def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.") .def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.") + .def("cipher_time_window", &PSContext::cipher_time_window, "Get time window for cipher rounds.") .def("set_reconstruct_secrets_threshold", &PSContext::set_reconstruct_secrets_threshold, "Set threshold count for reconstruct secrets round.") .def("reconstruct_secrets_threshold", &PSContext::reconstruct_secrets_threshold, @@ -416,16 +417,38 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.") .def("client_batch_size", &PSContext::client_batch_size, "Get federated learning client batch size.") .def("set_client_learning_rate", &PSContext::set_client_learning_rate, - "Set worker's standalone training step number before communicating with server.") + "Set federated learning client learning rate.") .def("client_learning_rate", &PSContext::client_learning_rate, "Get worker's standalone training step number before communicating with server.") .def("set_worker_step_num_per_iteration", &PSContext::set_worker_step_num_per_iteration, - "Set federated learning client learning rate.") + "Set worker's standalone training step number before communicating with server..") .def("worker_step_num_per_iteration", &PSContext::worker_step_num_per_iteration, "Get federated learning client learning rate.") + .def("set_secure_aggregation", &PSContext::set_secure_aggregation, + "Set federated learning client using secure aggregation.") + .def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.") + .def("dp_eps", &PSContext::dp_eps, "Get dp epsilon for federated learning secure aggregation.") + .def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.") + .def("dp_delta", &PSContext::dp_delta, "Get dp delta for federated learning secure aggregation.") + .def("set_dp_norm_clip", &PSContext::set_dp_norm_clip, + "Set dp norm clip for federated learning secure aggregation.") + .def("dp_norm_clip", &PSContext::dp_norm_clip, "Get dp norm clip for federated learning secure aggregation.") + .def("set_encrypt_type", &PSContext::set_encrypt_type, + "Set encrypt type for federated learning secure aggregation.") + .def("encrypt_type", &PSContext::encrypt_type, "Get encrypt type for federated learning secure aggregation.") + .def("set_root_first_ca_path", &PSContext::set_root_first_ca_path, "Set root first ca path.") + .def("root_first_ca_path", &PSContext::root_first_ca_path, "Get root first ca path.") + .def("set_root_second_ca_path", &PSContext::set_root_second_ca_path, "Set root second ca path.") + .def("root_second_ca_path", &PSContext::root_second_ca_path, "Get root second ca path.") + .def("set_pki_verify", &PSContext::set_pki_verify, "Set pki verify.") + .def("pki_verify", &PSContext::pki_verify, "Get pki verify.") .def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port, "Set scheduler manage port used to scale out/in.") .def("scheduler_manage_port", &PSContext::scheduler_manage_port, "Get scheduler manage port used to scale out/in.") + .def("set_equip_crl_path", &PSContext::set_equip_crl_path, "Set root second crl path.") + .def("set_replay_attack_time_diff", &PSContext::set_replay_attack_time_diff, "Set replay attack time diff.") + .def("equip_crl_path", &PSContext::equip_crl_path, "Get root second crl path.") + .def("replay_attack_time_diff", &PSContext::replay_attack_time_diff, "Get replay attack time diff.") .def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.") .def("enable_ssl", &PSContext::enable_ssl, "Get PS SSL mode enabled or disabled.") .def("set_client_password", &PSContext::set_client_password, "Set the client password to decode the p12 file.") @@ -441,7 +464,9 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_dp_norm_clip", &PSContext::set_dp_norm_clip, "Set dp norm clip for federated learning secure aggregation.") .def("set_encrypt_type", &PSContext::set_encrypt_type, - "Set encrypt type for federated learning secure aggregation."); + "Set encrypt type for federated learning secure aggregation.") + .def("set_http_url_prefix", &PSContext::set_http_url_prefix, "Set http url prefix for http communication.") + .def("http_url_prefix", &PSContext::http_url_prefix, "http url prefix for http communication."); (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data."); (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data."); diff --git a/mindspore/ccsrc/ps/constants.h b/mindspore/ccsrc/ps/constants.h index 0091ca8cc2e..5731cfcefd9 100644 --- a/mindspore/ccsrc/ps/constants.h +++ b/mindspore/ccsrc/ps/constants.h @@ -76,28 +76,15 @@ constexpr int64_t kPullCmd = 51; constexpr size_t kInvalidKey = UINT64_MAX; constexpr int64_t kInvalidID = -1; -constexpr int64_t kGradIndex = 0; -constexpr int64_t kIndiceIndex = 1; -constexpr int64_t kFirstDimSize = 2; -constexpr int64_t kOutDimSize = 3; - -constexpr int64_t kBase = 10; -constexpr float kStdDev = 0.01; - -constexpr int64_t kSparseLazyAdamIndex = 2; -constexpr int64_t kSparseFtrlIndex = 3; -constexpr int64_t kSparseGradIndex = 6; -constexpr int64_t kSparseIndiceIndex = 7; - -constexpr int64_t kHeartbeatTimes = 2; -constexpr int64_t kGradValue = -100; - constexpr uint32_t kMaxMessageSize = static_cast(100 * (uint32_t(1) << 20)); constexpr char kServerNum[] = "server_num"; constexpr char kWorkerNum[] = "worker_num"; constexpr char kNodesIds[] = "node_ids"; constexpr char kNodeId[] = "node_id"; +constexpr char kSuccessCode[] = "0"; +constexpr char kErrorCode[] = "1"; + constexpr int64_t kSubmitTaskIntervalInMs = 1; constexpr int64_t kMaxTaskNum = 10240; constexpr int64_t kSubmitTimeOutInMs = 30000; @@ -105,7 +92,13 @@ constexpr int64_t kRetryCount = 60; constexpr int64_t kRetryIntervalInMs = 10; constexpr int64_t kThreadNum = 32; +constexpr int64_t kGradIndex = 0; +constexpr int64_t kIndiceIndex = 1; +constexpr int64_t kFirstDimSize = 2; +constexpr int64_t kOutDimSize = 3; +constexpr int64_t kBase = 10; +constexpr float kStdDev = 0.01; // The timeout period for the scale in node to send the finish message to scheduler. constexpr uint32_t kScaleInTimeoutInSenconds = 30; // The number of retries to determine whether all nodes are successfully registered. @@ -113,10 +106,26 @@ constexpr uint32_t kCheckRegisteredRetryCount = 30; // The timeout interval for judging whether all nodes are successfully registered. constexpr uint32_t kCheckRegisteredIntervalInMs = 1000; +// The barrier function which should be called before doing scaling out/in operations. +// It's easy for us to scale out/in nodes after one iteration is completed and keep consistent. +using BarrierBeforeScaleOut = std::function; +using BarrierBeforeScaleIn = std::function; + +constexpr int64_t kSparseLazyAdamIndex = 2; +constexpr int64_t kSparseFtrlIndex = 3; +constexpr int64_t kSparseGradIndex = 6; +constexpr int64_t kSparseIndiceIndex = 7; + +constexpr int64_t kHeartbeatTimes = 2; +constexpr int64_t kGradValue = -100; +// Whether to support recovery. +constexpr char kIsRecovery[] = "is_recovery"; // The type of persistent storage, currently only supports file storage. constexpr char kStoreType[] = "storage_type"; // The file used to storage metadata. constexpr char kStoreFilePath[] = "storage_file_path"; +// The file used to storage scheduler metadata. +constexpr char kSchedulerStoreFilePath[] = "scheduler_storage_file_path"; // 1 indicates that the persistent storage type is file. constexpr char kFileStorage[] = "1"; // The recovery key of json_config. @@ -125,6 +134,10 @@ constexpr char kRecoveryWorkerNum[] = "worker_num"; constexpr char kRecoveryServerNum[] = "server_num"; constexpr char kRecoverySchedulerIp[] = "scheduler_ip"; constexpr char kRecoverySchedulerPort[] = "scheduler_port"; +constexpr char kRecoveryTotalNodeNum[] = "total_node_num"; +constexpr char kRecoveryNextWorkerRankId[] = "next_worker_rank_id"; +constexpr char kRecoveryNextServerRankId[] = "next_server_rank_id"; +constexpr char kRecoveryRegisteredNodesInfos[] = "node_ids"; constexpr char kServerCertPath[] = "server_cert_path"; constexpr char kServerPassword[] = "server_password"; @@ -132,7 +145,6 @@ constexpr char kCrlPath[] = "crl_path"; constexpr char kClientCertPath[] = "client_cert_path"; constexpr char kClientPassword[] = "client_password"; constexpr char kCaCertPath[] = "ca_cert_path"; - constexpr char kCipherList[] = "cipher_list"; constexpr char kCertCheckInterval[] = "cert_check_interval_in_hour"; // 7 * 24 @@ -154,6 +166,7 @@ constexpr int64_t kMaxWarningTime = 180; constexpr int64_t kLength = 100; constexpr int64_t kMaxPort = 65535; +constexpr int64_t kSecurityLevel = 3; constexpr char kTcpCommunicator[] = "TCP"; constexpr char kHttpCommunicator[] = "HTTP"; @@ -162,35 +175,14 @@ constexpr char kServerCert[] = "server.p12"; constexpr char kClientCert[] = "client.p12"; constexpr char kCaCert[] = "ca.crt"; constexpr char kColon = ':'; -const std::map kCiphers = {{"ECDHE-RSA-AES128-GCM-SHA256", 0}, - {"ECDHE-ECDSA-AES128-GCM-SHA256", 1}, - {"ECDHE-RSA-AES256-GCM-SHA384", 2}, - {"ECDHE-ECDSA-AES256-GCM-SHA384", 3}, - {"DHE-RSA-AES128-GCM-SHA256", 4}, - {"DHE-DSS-AES128-GCM-SHA256", 5}, - {"ECDHE-RSA-AES128-SHA256", 6}, - {"ECDHE-ECDSA-AES128-SHA256", 7}, - {"ECDHE-RSA-AES128-SHA", 8}, - {"ECDHE-ECDSA-AES128-SHA", 9}, - {"ECDHE-RSA-AES256-SHA384", 10}, - {"ECDHE-ECDSA-AES256-SHA384", 11}, - {"ECDHE-RSA-AES256-SHA", 12}, - {"ECDHE-ECDSA-AES256-SHA", 13}, - {"DHE-RSA-AES128-SHA256", 14}, - {"DHE-RSA-AES128-SHA", 15}, - {"DHE-DSS-AES128-SHA256", 16}, - {"DHE-RSA-AES256-SHA256", 17}, - {"DHE-DSS-AES256-SHA", 18}, - {"DHE-RSA-AES256-SHA", 19}, - {"!aNULL", 20}, - {"!eNULL", 21}, - {"!EXPORT", 22}, - {"!DES", 23}, - {"!RC4", 24}, - {"!3DES", 25}, - {"!MD5", 26}, - {"!PSK", 27}, - {"kEDH+AESGCM", 28}}; +const std::map kCiphers = { + {"ECDHE-RSA-AES128-GCM-SHA256", 0}, {"ECDHE-ECDSA-AES128-GCM-SHA256", 1}, {"ECDHE-RSA-AES256-GCM-SHA384", 2}, + {"ECDHE-ECDSA-AES256-GCM-SHA384", 3}, {"DHE-RSA-AES128-GCM-SHA256", 4}, {"DHE-DSS-AES128-GCM-SHA256", 5}, + {"DHE-RSA-AES256-GCM-SHA384", 6}, {"DHE-DSS-AES256-GCM-SHA384", 7}, {"DHE-PSK-AES128-GCM-SHA256", 8}, + {"DHE-PSK-AES256-GCM-SHA384", 9}, {"DHE-PSK-CHACHA20-POLY1305", 10}, {"ECDHE-RSA-CHACHA20-POLY1305", 11}, + {"ECDHE-PSK-CHACHA20-POLY1305", 12}, {"DHE-RSA-AES128-CCM", 13}, {"DHE-RSA-AES256-CCM", 14}, + {"DHE-RSA-CHACHA20-POLY1305", 15}, {"DHE-PSK-AES128-CCM", 16}, {"DHE-PSK-AES256-CCM", 17}, + {"ECDHE-ECDSA-AES128-CCM", 18}, {"ECDHE-ECDSA-AES256-CCM", 19}, {"ECDHE-ECDSA-CHACHA20-POLY1305", 20}}; #ifdef __APPLE__ using DataPtr = std::shared_ptr; @@ -262,7 +254,7 @@ using HandlerAfterScaleIn = std::function; constexpr char kClusterSafeMode[] = "The cluster is in safemode."; constexpr char kJobNotAvailable[] = "The server's training job is disabled or finished."; -enum class CustomEvent { kIterationRunning = 0, kIterationCompleted }; +enum class UserDefineEvent { kIterationRunning = 0, kIterationCompleted, kNodeTimeout }; #define EXC_IF_VEC_IDX_OOB(vec, idx) \ { \ diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 00093010b35..07593d045d6 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,9 +50,18 @@ void AbstractNode::ProcessRegisterResp(const std::shared_ptr &meta, MS_EXCEPTION_IF_NULL(data); RegisterRespMessage register_resp_message; CHECK_RETURN_TYPE(register_resp_message.ParseFromArray(data, SizeToInt(size))); + MS_LOG(INFO) << "The node id get from scheduler is:" << register_resp_message.node_id() + << ", rank_id is:" << register_resp_message.rank_id(); + if (register_resp_message.node_id() != node_info_.node_id_) { - MS_LOG(EXCEPTION) << "The node id received:" << register_resp_message.node_id() - << " is not match the current node id:" << node_info_.node_id_; + MS_LOG(ERROR) << "The node id received:" << register_resp_message.node_id() + << " is not match the current node id:" << node_info_.node_id_; + return; + } + node_info_.rank_id_ = register_resp_message.rank_id(); + if (node_info_.rank_id_ == UINT32_MAX) { + MS_LOG(ERROR) << "The rank id received:" << register_resp_message.rank_id(); + return; } // Receive the Register message, indicating that the scheduler is alive, so update the time point at which the @@ -70,7 +79,7 @@ bool AbstractNode::Broadcast(const NodeRole &node_role, const DataPtr &message, } uint32_t broadcast_size = 0; - std::for_each(nodes_address_.begin(), nodes_address_.end(), [&broadcast_size, &node_role](const auto &addr) { + (void)std::for_each(nodes_address_.begin(), nodes_address_.end(), [&broadcast_size, &node_role](const auto &addr) { if (addr.first.first == node_role) { ++broadcast_size; } @@ -160,19 +169,24 @@ void AbstractNode::BroadcastEvent(const uint32_t &event) { MS_EXCEPTION_IF_NULL(message_meta); message_meta->set_cmd(NodeCommand::SEND_EVENT); - EventMessage event_message; - event_message.set_event(event); - event_message.set_node_id(node_info_.node_id_); + EventRespMessage event_resp_message; + event_resp_message.set_event(event); - if (!SendMessageSync(client_to_scheduler_, message_meta, Protos::PROTOBUF, event_message.SerializeAsString().data(), - event_message.ByteSizeLong())) { - MS_LOG(ERROR) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) - << " the node id:" << node_info_.node_id_ << " send event timeout!"; - return; + for (auto it = nodes_address_.begin(); it != nodes_address_.end(); ++it) { + const uint32_t rank_id = (*it).first.second; + const NodeRole role = (*it).first.first; + auto client = GetOrCreateTcpClient(rank_id, role); + if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, event_resp_message.SerializeAsString().data(), + event_resp_message.ByteSizeLong())) { + MS_LOG(ERROR) << "send event to node role:" << CommUtil::NodeRoleToString(role) << ", rank id:" << rank_id + << " timeout!"; + } else { + MS_LOG(INFO) << "send event to node role:" << CommUtil::NodeRoleToString(role) << ", rank id:" << rank_id + << " successful!"; + } } - MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) - << " the node id:" << node_info_.node_id_ << "is send event to scheduler!"; + << " the node id:" << node_info_.node_id_ << "is send event to server/worker!"; } void AbstractNode::RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb) { @@ -185,10 +199,6 @@ void AbstractNode::RegisterCustomEventCallback(const uint32_t &event, const Even bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command, const uint32_t &timeout) { - if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) { - MS_LOG(DEBUG) << "The node is timeout, can not send message."; - return false; - } MS_EXCEPTION_IF_NULL(data); if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) { MS_LOG(EXCEPTION) << "The node role or rank_id is illegal, the worker num:" << worker_num_ @@ -209,11 +219,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, cons bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, const std::vector &lens, int command, const uint32_t &timeout) { - if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) { - MS_LOG(DEBUG) << "The node is timeout, can not send message."; - return false; - } - uint64_t request_id = AddMessageTrack(data.size()); if (rank_ids.size() != data.size() || rank_ids.size() != lens.size()) { @@ -248,10 +253,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const std::vector & bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command, VectorPtr *output, const uint32_t &timeout) { - if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) { - MS_LOG(DEBUG) << "The node is timeout, can not send message."; - return false; - } MS_EXCEPTION_IF_NULL(message); MS_EXCEPTION_IF_NULL(output); if (!CommUtil::ValidateRankId(node_role, rank_id, worker_num_, server_num_)) { @@ -289,10 +290,6 @@ bool AbstractNode::Send(const NodeRole &node_role, const uint32_t &rank_id, cons bool AbstractNode::Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, const std::vector &data_lens, int command, std::vector *output, const uint32_t &timeout) { - if (current_cluster_state_ == ClusterState::NODE_TIMEOUT) { - MS_LOG(DEBUG) << "The node is timeout, can not send message."; - return false; - } MS_EXCEPTION_IF_NULL(output); uint64_t request_id = AddMessageTrack(data.size()); @@ -493,10 +490,6 @@ std::shared_ptr AbstractNode::GetOrCreateTcpComm(const std::st if (!communicators_.count(kTcpCommunicator)) { MS_LOG(INFO) << "Create Tcp communicator."; auto tcp_comm = std::make_shared(task_executor, this); - PSContext::instance()->cluster_config().scheduler_host = scheduler_ip; - PSContext::instance()->cluster_config().scheduler_port = static_cast(scheduler_port); - PSContext::instance()->cluster_config().initial_worker_num = worker_num; - PSContext::instance()->cluster_config().initial_server_num = server_num; MS_EXCEPTION_IF_NULL(tcp_comm); PSContext::instance()->cluster_config().scheduler_host = scheduler_ip; PSContext::instance()->cluster_config().scheduler_port = static_cast(scheduler_port); @@ -521,13 +514,7 @@ void AbstractNode::StartHeartbeatTimer(const std::shared_ptr &client) MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) << ", the node id is:" << node_info_.node_id_ << " Send heartbeat timeout!"; if (CheckSchedulerTimeout()) { - MS_LOG(WARNING) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) - << ", the node id is:" << node_info_.node_id_ << " exited due to scheduler timeout!"; - is_finish_ = true; - wait_finish_cond_.notify_all(); - if (!is_already_stopped_) { - OnEventCallback(ClusterEvent::SCHEDULER_TIMEOUT); - } + MS_LOG(WARNING) << "Scheduler is Timeout, please recovery."; } } else { UpdateSchedulerTime(); @@ -549,6 +536,7 @@ bool AbstractNode::Heartbeat(const std::shared_ptr &client) { HeartbeatMessage heartbeat_message; heartbeat_message.set_node_id(node_info_.node_id_); + MS_LOG(DEBUG) << "The node id:" << node_info_.node_id_ << " Send heartbeat!"; if (!SendMessageSync(client, meta, Protos::PROTOBUF, heartbeat_message.SerializeAsString().data(), heartbeat_message.ByteSizeLong(), kCommTimeoutInSeconds)) { MS_LOG(WARNING) << "The node id:" << node_info_.node_id_ << " Send heartbeat timeout!"; @@ -580,20 +568,31 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr &meta HeartbeatRespMessage heartbeat_resp_message; CHECK_RETURN_TYPE(heartbeat_resp_message.ParseFromArray(data, SizeToInt(size))); + if (heartbeat_resp_message.cluster_state() != current_cluster_state_) { + MS_LOG(INFO) << "cluster change state from:" << CommUtil::ClusterStateToString(current_cluster_state_) << " to " + << CommUtil::ClusterStateToString(heartbeat_resp_message.cluster_state()); + } + current_cluster_state_ = heartbeat_resp_message.cluster_state(); MS_LOG(DEBUG) << "The current cluster state from heartbeat:" << CommUtil::ClusterStateToString(current_cluster_state_); + std::string timeoutNodeId; + all_nodes_info_.clear(); for (const auto &it : heartbeat_resp_message.servers_meta()) { NodeInfo info; info.ip_ = it.ip(); info.node_id_ = it.node_id(); - info.port_ = static_cast(it.port()); + info.port_ = it.port(); info.node_role_ = it.role(); info.rank_id_ = it.rank_id(); info.is_alive = it.is_alive(); + if (!info.is_alive) { + timeoutNodeId += (info.node_id_ + " "); + } + all_nodes_info_[info.node_id_] = info; MS_LOG(DEBUG) << "The node id:" << info.node_id_ << ", the rank id:" << info.rank_id_ << ", the node role:" << CommUtil::NodeRoleToString(info.node_role_) << " is alive:" << info.is_alive; @@ -608,7 +607,8 @@ void AbstractNode::ProcessHeartbeatResp(const std::shared_ptr &meta wait_start_cond_.notify_all(); OnEventCallback(ClusterEvent::NODE_TIMEOUT); } else { - MS_LOG(INFO) << "The node is support recovery, users can pull up this node to restore the cluster."; + MS_LOG(INFO) << "The nodes:" << timeoutNodeId + << "is support recovery, users can pull up this node to restore the cluster."; } } } @@ -653,14 +653,14 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr &con return; } SendMetadataMessage send_meta_message; - CHECK_RETURN_TYPE(send_meta_message.ParseFromArray(data, SizeToInt(size))); + send_meta_message.ParseFromArray(data, SizeToInt(size)); worker_num_ = send_meta_message.worker_num(); server_num_ = send_meta_message.server_num(); if (send_meta_message.rank_id() < 0) { MS_LOG(EXCEPTION) << "The rank id is wrong."; } node_info_.rank_id_ = send_meta_message.rank_id(); - current_cluster_state_ = send_meta_message.cluster_state(); + UpdateClusterState(send_meta_message.cluster_state()); MS_LOG(INFO) << "The send metadata worker num:" << worker_num_ << ", server num:" << server_num_ << ", cluster state is:" << CommUtil::ClusterStateToString(current_cluster_state_) << ", the rank id:" << node_info_.rank_id_; @@ -669,7 +669,8 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr &con nodes_address_.clear(); for (const auto &it : send_meta_message.servers_meta()) { nodes_address_[std::make_pair(it.role(), it.rank_id())] = std::make_pair(it.ip(), it.port()); - MS_LOG(INFO) << "The server ip is:" << it.ip() << ", the port is:" << it.port() << ", the rank id:" << it.rank_id(); + MS_LOG(INFO) << "The node role:" << CommUtil::NodeRoleToString(it.role()) << ", node id:" << it.node_id() + << ", rank id:" << it.rank_id() << ", ip:" << it.ip() << ", port:" << it.port(); } client_mutex_.unlock(); if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) { @@ -690,6 +691,7 @@ void AbstractNode::ProcessSendMetadata(const std::shared_ptr &con std::lock_guard lock(client_mutex_); connected_nodes_.clear(); + PersistMetaData(); } void AbstractNode::ProcessFinish(const std::shared_ptr &conn, const std::shared_ptr &meta, @@ -710,11 +712,13 @@ void AbstractNode::ProcessScaleOutDone(const std::shared_ptr &con MS_EXCEPTION_IF_NULL(conn); MS_EXCEPTION_IF_NULL(meta); MS_EXCEPTION_IF_NULL(data); + MS_LOG(INFO) << "This node receive a scale out done from scheduler."; if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) { MS_LOG(WARNING) << "Server response message failed."; } is_ready_ = true; - current_cluster_state_ = ClusterState::CLUSTER_READY; + UpdateClusterState(ClusterState::CLUSTER_READY); + PersistMetaData(); } void AbstractNode::ProcessScaleInDone(const std::shared_ptr &conn, @@ -727,7 +731,8 @@ void AbstractNode::ProcessScaleInDone(const std::shared_ptr &conn MS_LOG(WARNING) << "Server response message failed."; } is_ready_ = true; - current_cluster_state_ = ClusterState::CLUSTER_READY; + UpdateClusterState(ClusterState::CLUSTER_READY); + PersistMetaData(); } void AbstractNode::ProcessEvent(const std::shared_ptr &conn, const std::shared_ptr &meta, @@ -736,12 +741,17 @@ void AbstractNode::ProcessEvent(const std::shared_ptr &conn, cons MS_EXCEPTION_IF_NULL(meta); MS_EXCEPTION_IF_NULL(data); EventRespMessage event_resp_message; - CHECK_RETURN_TYPE(event_resp_message.ParseFromArray(data, SizeToInt(size))); + event_resp_message.ParseFromArray(data, SizeToInt(size)); uint32_t event = event_resp_message.event(); if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) { MS_LOG(WARNING) << "Server response message failed."; } - OnCustomEventCallback(event); + MS_LOG(INFO) << "This node receive a event:" << event; + if (event == static_cast(ps::UserDefineEvent::kNodeTimeout)) { + OnEventCallback(ClusterEvent::NODE_TIMEOUT); + } else { + OnCustomEventCallback(event); + } } void AbstractNode::ProcessScaleOut(const std::shared_ptr &conn, const std::shared_ptr &meta, @@ -751,7 +761,7 @@ void AbstractNode::ProcessScaleOut(const std::shared_ptr &conn, c MS_EXCEPTION_IF_NULL(data); ScaleOutMessage scale_out_message; - CHECK_RETURN_TYPE(scale_out_message.ParseFromArray(data, SizeToInt(size))); + scale_out_message.ParseFromArray(data, SizeToInt(size)); int32_t worker_num = scale_out_message.worker_num(); int32_t server_num = scale_out_message.server_num(); MS_LOG(WARNING) << "The scale out worker num:" << worker_num << ", the server num:" << server_num; @@ -760,7 +770,7 @@ void AbstractNode::ProcessScaleOut(const std::shared_ptr &conn, c MS_LOG(WARNING) << "Server response message failed."; } OnEventCallback(ClusterEvent::READY_FOR_SCALE_OUT); - current_cluster_state_ = ClusterState::CLUSTER_SCALE_OUT; + UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT); is_ready_ = false; } @@ -771,7 +781,7 @@ void AbstractNode::ProcessScaleIn(const std::shared_ptr &conn, co MS_EXCEPTION_IF_NULL(data); ScaleInMessage scale_in_message; - CHECK_RETURN_TYPE(scale_in_message.ParseFromArray(data, SizeToInt(size))); + scale_in_message.ParseFromArray(data, SizeToInt(size)); int32_t worker_num = scale_in_message.worker_num(); int32_t server_num = scale_in_message.server_num(); MS_LOG(WARNING) << "The scale in worker num:" << worker_num << ", the server num:" << server_num; @@ -789,7 +799,44 @@ void AbstractNode::ProcessScaleIn(const std::shared_ptr &conn, co MS_LOG(WARNING) << "Server response message failed."; } OnEventCallback(ClusterEvent::READY_FOR_SCALE_IN); - current_cluster_state_ = ClusterState::CLUSTER_SCALE_IN; + UpdateClusterState(ClusterState::CLUSTER_SCALE_IN); + is_ready_ = false; +} + +void AbstractNode::ProcessSchedulerRecovery(const std::shared_ptr &conn, + const std::shared_ptr &meta, const Protos &, const void *data, + size_t size) { + MS_EXCEPTION_IF_NULL(conn); + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(data); + if (is_connected_to_scheduler_.load()) { + MS_LOG(WARNING) << "This node has been connected to scheduler."; + return; + } + SendMetadataMessage scheduler_recovery_message; + (void)scheduler_recovery_message.ParseFromArray(data, SizeToInt(size)); + worker_num_ = scheduler_recovery_message.worker_num(); + server_num_ = scheduler_recovery_message.server_num(); + uint32_t rank_id = scheduler_recovery_message.rank_id(); + + MS_LOG(INFO) << "[Scheduler Recovery]: The scheduler recovery worker num:" << worker_num_ + << ", the server num:" << server_num_ << ", the rank id: " << rank_id; + + if (!server_->SendMessage(conn, meta, Protos::RAW, data, size)) { + MS_LOG(WARNING) << "[Scheduler Recovery]: Server response message failed."; + } + MS_LOG(INFO) << "[Scheduler Recovery]: Server response message success!."; + + if (!InitClientToScheduler()) { + MS_LOG(WARNING) << "[Scheduler Recovery]: Server node connect to scheduler timedout!"; + } + + Register(client_to_scheduler_); + std::lock_guard lock(client_mutex_); + connected_nodes_.clear(); + MS_LOG(INFO) << "[Scheduler Recovery]: This node connect to scheduler successful!"; + + UpdateClusterState(ClusterState::CLUSTER_SCHEDULER_RECOVERY); is_ready_ = false; } @@ -819,6 +866,14 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) { return res; } +void AbstractNode::InitClientToServer() { + // create tcp client to myself in case of event dispatch failed when Send msg to server 0 failed + client_to_server_ = std::make_shared(node_info_.ip_, node_info_.port_, config_.get()); + MS_EXCEPTION_IF_NULL(client_to_server_); + client_to_server_->Init(); + MS_LOG(INFO) << "The node start a tcp client to this node!"; +} + bool AbstractNode::InitClientToScheduler() { if (config_ == nullptr) { MS_LOG(WARNING) << "The config is empty."; @@ -843,7 +898,6 @@ bool AbstractNode::InitClientToScheduler() { MsException::Instance().SetException(); } }); - client_to_scheduler_->Init(); client_to_scheduler_thread_ = std::make_unique([&]() { MS_LOG(INFO) << "The node start a tcp client!"; @@ -851,11 +905,14 @@ bool AbstractNode::InitClientToScheduler() { }); client_to_scheduler_thread_->detach(); + client_to_scheduler_->set_connected_callback([&]() { is_connected_to_scheduler_ = true; }); + client_to_scheduler_->set_disconnected_callback([&]() { std::this_thread::sleep_for(std::chrono::milliseconds(PSContext::instance()->cluster_config().connect_interval)); if (is_ready_.load() == false) { client_to_scheduler_->Init(); } + is_connected_to_scheduler_ = false; }); bool wait_res = client_to_scheduler_->WaitConnected(); if (!wait_res) { @@ -892,6 +949,9 @@ const std::shared_ptr &AbstractNode::GetOrCreateTcpClient(const uint3 case NodeCommand::COLLECTIVE_SEND_DATA: MS_LOG(DEBUG) << "The Node id:" << node_info_.node_id_ << " receive a collective_send_data message response!"; break; + case NodeCommand::SEND_EVENT: + MS_LOG(INFO) << "The Node id:" << node_info_.node_id_ << " receive a send_event command message response!"; + break; default: MS_LOG(EXCEPTION) << "The cmd:" << meta->cmd() << " is not supported!"; } @@ -964,8 +1024,9 @@ void AbstractNode::ProcessSendData(const std::shared_ptr &conn, c if (size > 0) { size_t dest_size = size; size_t src_size = size; - if (memcpy_s(res.get(), dest_size, data, src_size) != EOK) { - MS_LOG(EXCEPTION) << "The memcpy_s error"; + auto ret = memcpy_s(res.get(), dest_size, data, src_size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; } } MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) @@ -1066,6 +1127,7 @@ void AbstractNode::InitServerHandler() { server_handler_[NodeCommand::SCALE_OUT_DONE] = &AbstractNode::ProcessScaleOutDone; server_handler_[NodeCommand::SCALE_IN_DONE] = &AbstractNode::ProcessScaleInDone; server_handler_[NodeCommand::SEND_EVENT] = &AbstractNode::ProcessEvent; + server_handler_[NodeCommand::SCHEDULER_RECOVERY] = &AbstractNode::ProcessSchedulerRecovery; } void AbstractNode::InitNodeInfo(const NodeRole &role) { @@ -1090,8 +1152,8 @@ void AbstractNode::InitNodeInfo(const NodeRole &role) { } void AbstractNode::InitNodeNum() { - worker_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_worker_num); - server_num_ = SizeToInt(PSContext::instance()->cluster_config().initial_server_num); + worker_num_ = UintToInt(PSContext::instance()->cluster_config().initial_worker_num); + server_num_ = UintToInt(PSContext::instance()->cluster_config().initial_server_num); scheduler_ip_ = PSContext::instance()->cluster_config().scheduler_host; scheduler_port_ = PSContext::instance()->cluster_config().scheduler_port; MS_LOG(INFO) << "The worker num:" << worker_num_ << ", the server num:" << server_num_ @@ -1104,7 +1166,10 @@ bool AbstractNode::Recover() { MS_LOG(INFO) << "The node is support recovery."; node_recovery_ = std::make_unique(this); MS_EXCEPTION_IF_NULL(node_recovery_); - node_recovery_->Initialize(config_->Get(kKeyRecovery, "")); + if (!node_recovery_->Initialize(config_->Get(kKeyRecovery, ""))) { + MS_LOG(ERROR) << "Initializing node recovery failed."; + return false; + } return node_recovery_->Recover(); } return false; @@ -1132,7 +1197,7 @@ void AbstractNode::OnCustomEventCallback(const uint32_t &event) { } } -bool AbstractNode::IsWorkerOrServer0(const mindspore::HashMap &info) { +bool AbstractNode::IsWorkerOrServer0(const std::unordered_map &info) { for (const auto &it : info) { if (it.second.is_alive == true && it.second.node_role_ == NodeRole::WORKER) { return true; @@ -1149,7 +1214,12 @@ void AbstractNode::CreateTcpServer() { MS_EXCEPTION_IF_NULL(config_); std::string interface; std::string server_ip; - CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); + if (ps::PSContext::instance()->server_mode().empty()) { + // If the server mode is not set, use 127.0.0.1 as server ip address for distributed learning. + server_ip = "127.0.0.1"; + } else { + CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); + } server_ = std::make_shared(server_ip, 0, config_.get()); MS_EXCEPTION_IF_NULL(server_); server_->SetMessageCallback([&](const std::shared_ptr &conn, const std::shared_ptr &meta, @@ -1179,6 +1249,29 @@ void AbstractNode::CreateTcpServer() { MS_EXCEPTION_IF_NULL(server_thread_); server_thread_->detach(); } + +void AbstractNode::UpdateClusterState(const ClusterState &state) { + std::lock_guard lk(cluster_state_mutex_); + MS_LOG(INFO) << "[state]: Cluster state change from:" << CommUtil::ClusterStateToString(current_cluster_state_) + << " to " << CommUtil::ClusterStateToString(state); + current_cluster_state_ = state; +} + +void AbstractNode::PersistMetaData() { + if (node_recovery_ == nullptr) { + MS_LOG(WARNING) << "node recovery is null, so don't persist meta data"; + return; + } + if (config_->Exists(kKeyRecovery)) { + ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); + clusterConfig.scheduler_host = this->scheduler_ip(); + clusterConfig.scheduler_port = this->scheduler_port(); + clusterConfig.initial_worker_num = worker_num_; + clusterConfig.initial_server_num = server_num_; + + node_recovery_->Persist(clusterConfig); + } +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/abstract_node.h b/mindspore/ccsrc/ps/core/abstract_node.h index f771f9b70e8..e2dfd75a9a0 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.h +++ b/mindspore/ccsrc/ps/core/abstract_node.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,8 +22,8 @@ #include #include #include +#include -#include "utils/hash_map.h" #include "ps/core/node.h" #include "ps/core/communicator/message.h" #include "ps/core/follower_scaler.h" @@ -44,10 +44,12 @@ class AbstractNode : public Node { : heart_beat_thread_(nullptr), client_to_scheduler_thread_(nullptr), client_to_scheduler_(nullptr), + client_to_server_(nullptr), server_(nullptr), server_thread_(nullptr), worker_num_(-1), server_num_(-1), + is_connected_to_scheduler_(false), is_current_node_scale_in_(false), follower_scaler_(nullptr), node_recovery_(nullptr), @@ -94,14 +96,14 @@ class AbstractNode : public Node { void RegisterCustomEventCallback(const uint32_t &event, const EventCallback &event_cb); bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &data, size_t len, int command, - const uint32_t &timeout = kTimeoutInSeconds); + const uint32_t &timeout = kCommTimeoutInSeconds); bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, - const std::vector &lens, int command, const uint32_t &timeout = kTimeoutInSeconds); + const std::vector &lens, int command, const uint32_t &timeout = kCommTimeoutInSeconds); bool Send(const NodeRole &node_role, const uint32_t &rank_id, const DataPtr &message, size_t len, int command, - VectorPtr *output, const uint32_t &timeout = kTimeoutInSeconds); + VectorPtr *output, const uint32_t &timeout = kCommTimeoutInSeconds); bool Send(const NodeRole &node_role, const std::vector &rank_ids, const std::vector &data, const std::vector &data_lens, int command, std::vector *output, - const uint32_t &timeout = kTimeoutInSeconds); + const uint32_t &timeout = kCommTimeoutInSeconds); uint64_t CollectiveSendAsync(const NodeRole &node_role, const uint32_t &rank_id, const void *data, size_t size); std::pair CollectiveReceiveAsync(const NodeRole &node_role, const uint32_t &rank_id, @@ -170,6 +172,10 @@ class AbstractNode : public Node { void ProcessScaleInDone(const std::shared_ptr &conn, const std::shared_ptr &meta, const Protos &protos, const void *data, size_t size); + // The worker/server processes the scheduler recovery message from scheduelr + void ProcessSchedulerRecovery(const std::shared_ptr &conn, const std::shared_ptr &meta, + const Protos &, const void *data, size_t size); + // The worker/server processes the SEND_EVENT message from scheduelr void ProcessEvent(const std::shared_ptr &conn, const std::shared_ptr &meta, const Protos &protos, const void *data, size_t size); @@ -180,6 +186,7 @@ class AbstractNode : public Node { bool Disconnect(const std::shared_ptr &client, const uint32_t &timeout); bool WaitForDisconnect(const uint32_t &timeout); bool InitClientToScheduler(); + void InitClientToServer(); const std::shared_ptr &GetOrCreateTcpClient(const uint32_t &rank_id, const NodeRole &role = NodeRole::SERVER); bool SendMessageSync(const std::shared_ptr &client, const CommMessage &message, @@ -212,14 +219,18 @@ class AbstractNode : public Node { // Trigger the callback corresponding to the custom event. void OnCustomEventCallback(const uint32_t &event); - bool IsWorkerOrServer0(const mindspore::HashMap &info); + bool IsWorkerOrServer0(const std::unordered_map &info); void CreateTcpServer(); + void UpdateClusterState(const ClusterState &state); + + void PersistMetaData(); + std::unique_ptr heart_beat_thread_; std::unique_ptr client_to_scheduler_thread_; std::shared_ptr client_to_scheduler_; - + std::shared_ptr client_to_server_; // the key is: , the value is: std::map, std::pair> nodes_address_; // the map's key is: rank_id @@ -233,13 +244,13 @@ class AbstractNode : public Node { std::condition_variable receive_cond_; // the key is rank_id, the value is rank_id's expected request_id - mindspore::HashMap expected_rank_request_ids_; + std::unordered_map expected_rank_request_ids_; // the key is rank_id, the value is rank_id's actual request_id - mindspore::HashMap actual_rank_request_ids_; + std::unordered_map actual_rank_request_ids_; std::mutex rank_request_ids_mutex; timeval scheduler_time_{0, 0}; - mindspore::HashMap handlers_; - mindspore::HashMap server_handler_; + std::unordered_map handlers_; + std::unordered_map server_handler_; // Workers and servers launch the server to process command: FINISH,SCALE_OUT,SCALE_IN,SEND_METADATA std::shared_ptr server_; @@ -247,7 +258,7 @@ class AbstractNode : public Node { int32_t worker_num_; int32_t server_num_; - + std::atomic is_connected_to_scheduler_; // Identify whether the current node is a scale in node. std::atomic is_current_node_scale_in_; @@ -273,11 +284,12 @@ class AbstractNode : public Node { uint16_t scheduler_port_; // Synchronize all node metadata from the scheduler. - mindspore::HashMap all_nodes_info_; + std::unordered_map all_nodes_info_; RequestHandler request_handler_; - mindspore::HashMap> communicators_; + std::unordered_map> communicators_; std::mutex communicator_mutex_; + std::mutex cluster_state_mutex_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/cluster_config.h b/mindspore/ccsrc/ps/core/cluster_config.h index a292aa02b52..4a5054d1a5a 100644 --- a/mindspore/ccsrc/ps/core/cluster_config.h +++ b/mindspore/ccsrc/ps/core/cluster_config.h @@ -21,8 +21,10 @@ #include #include #include +#include #include "utils/log_adapter.h" +#include "ps/core/node_info.h" namespace mindspore { namespace ps { @@ -38,10 +40,12 @@ struct ClusterConfig { scheduler_host(host), scheduler_port(port), heartbeat_timeout(30), - cluster_available_timeout(300), + cluster_available_timeout(900), connect_interval(3000), - scheduler_timeout(30) {} - + scheduler_timeout(30), + initial_total_node_num(0), + initial_next_worker_rank_id(0), + initial_next_server_rank_id(0) {} // Configure through environment variables:MS_WORKER_NUM uint32_t initial_worker_num; // Configure through environment variables:MS_SERVER_NUM @@ -59,6 +63,11 @@ struct ClusterConfig { uint32_t connect_interval; // When the scheduler exits, the worker and server can continue to work for 5 hours int64_t scheduler_timeout; + // the node that has bean registered to scheduler + std::unordered_map initial_registered_nodes_infos; + uint32_t initial_total_node_num; + uint32_t initial_next_worker_rank_id; + uint32_t initial_next_server_rank_id; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/cluster_metadata.h b/mindspore/ccsrc/ps/core/cluster_metadata.h index 271bfe5f9a2..6e70fbc60b8 100644 --- a/mindspore/ccsrc/ps/core/cluster_metadata.h +++ b/mindspore/ccsrc/ps/core/cluster_metadata.h @@ -32,7 +32,6 @@ namespace core { */ struct ClusterMetadata { ClusterMetadata(const uint32_t &worker, const uint32_t &server) : worker_num(worker), server_num(server) {} - uint32_t worker_num; uint32_t server_num; }; diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc index 6af15c8851e..b5ca3a2a2a6 100644 --- a/mindspore/ccsrc/ps/core/comm_util.cc +++ b/mindspore/ccsrc/ps/core/comm_util.cc @@ -96,6 +96,28 @@ void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *i freeifaddrs(if_address); } +std::string CommUtil::GetLoopBackInterfaceName() { + struct ifaddrs *if_address = nullptr; + struct ifaddrs *ifa = nullptr; + + if (getifaddrs(&if_address) == -1) { + MS_LOG(WARNING) << "Get ifaddrs failed."; + } + for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) { + if (ifa->ifa_addr == nullptr) { + continue; + } + + if (ifa->ifa_flags & IFF_LOOPBACK) { + MS_LOG(INFO) << "Loop back interface name is " << ifa->ifa_name; + return ifa->ifa_name; + } + } + MS_EXCEPTION_IF_NULL(if_address); + freeifaddrs(if_address); + return ""; +} + std::string CommUtil::GenerateUUID() { std::stringstream ss; int i; @@ -135,6 +157,18 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) { MS_LOG(EXCEPTION) << "The node role:" << role << " is illegal!"; } } + +NodeRole CommUtil::StringToNodeRole(const std::string &roleStr) { + if (roleStr == "SCHEDULER") { + return NodeRole::SCHEDULER; + } else if (roleStr == "SERVER") { + return NodeRole::SERVER; + } else if (roleStr == "WORKER") { + return NodeRole::WORKER; + } else { + MS_LOG(EXCEPTION) << "The node role string:" << roleStr << " is illegal!"; + } +} bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num, const int32_t &total_server_num) { if (node_role == NodeRole::SERVER && (rank_id > IntToUint(total_server_num) - 1)) { @@ -183,7 +217,7 @@ bool CommUtil::IsFileExists(const std::string &file) { } std::string CommUtil::ClusterStateToString(const ClusterState &state) { - MS_LOG(INFO) << "The cluster state:" << state; + MS_LOG(DEBUG) << "The cluster state:" << state; if (state < SizeToInt(kClusterState.size())) { return kClusterState.at(state); } else { @@ -191,21 +225,50 @@ std::string CommUtil::ClusterStateToString(const ClusterState &state) { } } -std::string CommUtil::ParseConfig(const Configuration &config, const std::string &data) { +std::string CommUtil::ParseConfig(const Configuration &config, const std::string &key) { if (!config.IsInitialized()) { MS_LOG(INFO) << "The config is not initialized."; return ""; } - if (!const_cast(config).Exists(data)) { - MS_LOG(INFO) << "The data:" << data << " is not exist."; + if (!const_cast(config).Exists(key)) { + MS_LOG(INFO) << "The key:" << key << " is not exist."; return ""; } - std::string path = config.GetString(data, ""); + std::string path = config.GetString(key, ""); return path; } +bool CommUtil::verifyCertTimeStamp(const X509 *cert) { + ASN1_TIME *start = X509_getm_notBefore(cert); + ASN1_TIME *end = X509_getm_notAfter(cert); + + int day = 0; + int sec = 0; + int ret = ASN1_TIME_diff(&day, &sec, start, NULL); + if (ret != 1) { + return false; + } + + if (day < 0 || sec < 0) { + MS_LOG(ERROR) << "cert start time is later than now time."; + return false; + } + day = 0; + sec = 0; + ret = ASN1_TIME_diff(&day, &sec, NULL, end); + if (ret != 1) { + return false; + } + + if (day < 0 || sec < 0) { + MS_LOG(ERROR) << "cert end time is sooner than now time."; + return false; + } + return true; +} + bool CommUtil::VerifyCertTime(const X509 *cert, int64_t time) { MS_EXCEPTION_IF_NULL(cert); ASN1_TIME *start = X509_getm_notBefore(cert); @@ -249,61 +312,55 @@ bool CommUtil::VerifyCRL(const X509 *cert, const std::string &crl_path) { MS_ERROR_IF_NULL_W_RET_VAL(cert, false); BIO *bio = BIO_new_file(crl_path.c_str(), "r"); MS_ERROR_IF_NULL_W_RET_VAL(bio, false); - X509_CRL *root_crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr); - MS_ERROR_IF_NULL_W_RET_VAL(root_crl, false); - EVP_PKEY *evp_pkey = X509_get_pubkey(const_cast(cert)); - MS_ERROR_IF_NULL_W_RET_VAL(evp_pkey, false); - int ret = X509_CRL_verify(root_crl, evp_pkey); + X509_CRL *root_crl = nullptr; + EVP_PKEY *evp_pkey = nullptr; + bool result = true; + do { + root_crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr); + + MS_ERROR_IF_NULL_W_RET_VAL(root_crl, false); + evp_pkey = X509_get_pubkey(const_cast(cert)); + MS_ERROR_IF_NULL_W_RET_VAL(evp_pkey, false); + + int ret = X509_CRL_verify(root_crl, evp_pkey); + if (ret == 1) { + MS_LOG(WARNING) << "Equip cert in root crl, verify failed"; + result = false; + break; + } + } while (0); + BIO_free_all(bio); - if (ret == 1) { - MS_LOG(WARNING) << "Equip cert in root crl, verify failed"; - return false; - } + EVP_PKEY_free(evp_pkey); + X509_CRL_free(root_crl); MS_LOG(INFO) << "VerifyCRL success."; - return true; + return result; } -bool CommUtil::VerifyCommonName(const X509 *cert, const std::string &ca_path) { - MS_ERROR_IF_NULL_W_RET_VAL(cert, false); - X509 *cert_temp = const_cast(cert); - char subject_cn[256] = ""; - char issuer_cn[256] = ""; - X509_NAME *subject_name = X509_get_subject_name(cert_temp); - X509_NAME *issuer_name = X509_get_issuer_name(cert_temp); - MS_ERROR_IF_NULL_W_RET_VAL(subject_name, false); - MS_ERROR_IF_NULL_W_RET_VAL(issuer_name, false); - if (!X509_NAME_get_text_by_NID(subject_name, NID_commonName, subject_cn, sizeof(subject_cn))) { - MS_LOG(WARNING) << "Get text by nid failed."; - return false; - } - if (!X509_NAME_get_text_by_NID(issuer_name, NID_commonName, issuer_cn, sizeof(issuer_cn))) { - MS_LOG(WARNING) << "Get text by nid failed."; - return false; - } - MS_LOG(INFO) << "the subject:" << subject_cn << ", the issuer:" << issuer_cn; +bool CommUtil::VerifyCommonName(const X509 *caCert, const X509 *subCert) { + MS_EXCEPTION_IF_NULL(caCert); + MS_EXCEPTION_IF_NULL(subCert); + char caSubjectCN[256] = ""; + char subIssuerCN[256] = ""; - BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r"); - MS_EXCEPTION_IF_NULL(ca_bio); - X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr); - MS_EXCEPTION_IF_NULL(ca_cert); - char ca_subject_cn[256] = ""; - char ca_issuer_cn[256] = ""; - X509_NAME *ca_subject_name = X509_get_subject_name(ca_cert); - X509_NAME *ca_issuer_name = X509_get_issuer_name(ca_cert); - MS_ERROR_IF_NULL_W_RET_VAL(ca_subject_name, false); - MS_ERROR_IF_NULL_W_RET_VAL(ca_issuer_name, false); - if (!X509_NAME_get_text_by_NID(ca_subject_name, NID_commonName, ca_subject_cn, sizeof(subject_cn))) { - MS_LOG(WARNING) << "Get text by nid failed."; + X509_NAME *caSubjectX509CN = X509_get_subject_name(caCert); + X509_NAME *subIssuerX509CN = X509_get_issuer_name(subCert); + + int ret = X509_NAME_get_text_by_NID(caSubjectX509CN, NID_commonName, caSubjectCN, sizeof(caSubjectCN)); + if (ret < 0) { return false; } - if (!X509_NAME_get_text_by_NID(ca_issuer_name, NID_commonName, ca_issuer_cn, sizeof(issuer_cn))) { - MS_LOG(WARNING) << "Get text by nid failed."; + ret = X509_NAME_get_text_by_NID(subIssuerX509CN, NID_commonName, subIssuerCN, sizeof(subIssuerCN)); + if (ret < 0) { return false; } - MS_LOG(INFO) << "the subject:" << ca_subject_cn << ", the issuer:" << ca_issuer_cn; - BIO_free_all(ca_bio); - if (strcmp(issuer_cn, ca_subject_cn) != 0) { + + std::string caSubjectCNStr = caSubjectCN; + std::string subIssuerCNStr = subIssuerCN; + + if (caSubjectCNStr != subIssuerCNStr) { + MS_LOG(EXCEPTION) << "root CA cert subject cn is not equal with equip CA cert issuer cn."; return false; } return true; @@ -330,19 +387,170 @@ bool CommUtil::VerifyCipherList(const std::vector &list) { return true; } -void CommUtil::InitOpenSSLEnv() { - if (!SSL_library_init()) { - MS_LOG(EXCEPTION) << "SSL_library_init failed."; +bool CommUtil::verifyCertKeyID(const X509 *caCert, const X509 *subCert) { + MS_EXCEPTION_IF_NULL(caCert); + MS_EXCEPTION_IF_NULL(subCert); + int crit = 0; + ASN1_OCTET_STRING *skid = + reinterpret_cast(X509_get_ext_d2i(caCert, NID_subject_key_identifier, &crit, NULL)); + MS_EXCEPTION_IF_NULL(skid); + const int keyidLen = 512; + char subject_keyid[keyidLen] = {0}; + for (int i = 0; i < skid->length; i++) { + char keyid[8] = {0}; + int base = keyidLen; + (void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)skid->data[i]); + int ret = strcat_s(subject_keyid, base, keyid); + if (ret == -1) { + return false; + } } - if (!ERR_load_crypto_strings()) { - MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed."; + + AUTHORITY_KEYID *akeyid = + reinterpret_cast(X509_get_ext_d2i(subCert, NID_authority_key_identifier, &crit, NULL)); + MS_EXCEPTION_IF_NULL(akeyid); + MS_EXCEPTION_IF_NULL(akeyid->keyid); + char issuer_keyid[keyidLen] = {0}; + for (int i = 0; i < akeyid->keyid->length; i++) { + char keyid[8] = {0}; + int base = keyidLen; + (void)sprintf_s(keyid, sizeof(keyid), "%x ", (uint32_t)(akeyid->keyid->data[i])); + int ret = strcat_s(issuer_keyid, base, keyid); + if (ret == -1) { + return false; + } } - if (!SSL_load_error_strings()) { - MS_LOG(EXCEPTION) << "SSL_load_error_strings failed."; + + std::string subject_keyid_str = subject_keyid; + std::string issuer_keyid_str = issuer_keyid; + if (subject_keyid_str != issuer_keyid_str) { + return false; } - if (!OpenSSL_add_all_algorithms()) { - MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed."; + return true; +} + +bool CommUtil::verifySingature(const X509 *caCert, const X509 *subCert) { + MS_EXCEPTION_IF_NULL(caCert); + MS_EXCEPTION_IF_NULL(subCert); + EVP_PKEY *caCertPubKey = X509_get_pubkey(const_cast(caCert)); + + int ret = 0; + ret = X509_verify(const_cast(subCert), caCertPubKey); + if (ret != 1) { + EVP_PKEY_free(caCertPubKey); + MS_LOG(ERROR) << "sub cert verify is failed"; + return false; } + MS_LOG(INFO) << "verifyCAChain success."; + EVP_PKEY_free(caCertPubKey); + return true; +} + +bool CommUtil::verifyExtendedAttributes(const X509 *caCert) { + MS_EXCEPTION_IF_NULL(caCert); + int cirt = 0; + BASIC_CONSTRAINTS *bcons = + reinterpret_cast(X509_get_ext_d2i(caCert, NID_basic_constraints, &cirt, NULL)); + if (bcons == nullptr) { + return false; + } + if (!bcons->ca) { + MS_LOG(ERROR) << "Subject Type is End Entity."; + return false; + } + MS_LOG(INFO) << "Subject Type is CA."; + return true; +} + +void CommUtil::verifyCertPipeline(const X509 *caCert, const X509 *subCert) { + if (!CommUtil::VerifyCommonName(caCert, subCert)) { + MS_LOG(EXCEPTION) << "Verify common name failed."; + } + + if (!CommUtil::verifySingature(caCert, subCert)) { + MS_LOG(EXCEPTION) << "Verify Singature failed."; + } + + if (!CommUtil::verifyExtendedAttributes(caCert)) { + MS_LOG(EXCEPTION) << "Verify Extended Attributes failed."; + } + + if (!CommUtil::verifyCertKeyID(caCert, subCert)) { + MS_LOG(EXCEPTION) << "Verify Cert KeyID failed."; + } + + if (!CommUtil::verifyCertTimeStamp(caCert) || !CommUtil::verifyCertTimeStamp(subCert)) { + MS_LOG(EXCEPTION) << "Verify Cert Time failed."; + } +} + +bool CommUtil::checkCRLTime(const std::string &crlPath) { + if (!IsFileExists(crlPath)) { + return true; + } + BIO *bio = BIO_new_file(crlPath.c_str(), "r"); + if (bio == nullptr) { + return true; + } + bool result = true; + X509_CRL *crl = nullptr; + do { + crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr); + if (crl == nullptr) { + MS_LOG(WARNING) << "crl is nullptr. return true."; + result = true; + break; + } + const ASN1_TIME *lastUpdate = X509_CRL_get0_lastUpdate(crl); + const ASN1_TIME *nextUpdate = X509_CRL_get0_nextUpdate(crl); + + int day = 0; + int sec = 0; + int ret = ASN1_TIME_diff(&day, &sec, lastUpdate, NULL); + if (ret != 1) { + result = false; + break; + } + + if (day < 0 || sec < 0) { + MS_LOG(ERROR) << "crl start time is later than now time."; + result = false; + break; + } + day = 0; + sec = 0; + ret = ASN1_TIME_diff(&day, &sec, NULL, nextUpdate); + if (ret != 1) { + result = false; + break; + } + + if (day < 0 || sec < 0) { + MS_LOG(WARNING) << "crl update time is sooner than now time. please update crl"; + } + MS_LOG(INFO) << "verifyCRL time success."; + } while (0); + + X509_CRL_free(crl); + BIO_free_all(bio); + return result; +} + +std::string CommUtil::BoolToString(bool alive) { + if (alive) { + return "True"; + } else { + return "False"; + } +} + +bool CommUtil::StringToBool(const std::string &alive) { + if (alive == "True") { + return true; + } else if (alive == "False") { + return false; + } + return false; } } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index f0dd069dfc3..500af5c8ebb 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -44,6 +44,7 @@ #include #include #include +#include #include #include @@ -99,8 +100,12 @@ class CommUtil { static bool CheckIp(const std::string &ip); static bool CheckPort(const uint16_t &port); static void GetAvailableInterfaceAndIP(std::string *interface, std::string *ip); + static std::string GetLoopBackInterfaceName(); static std::string GenerateUUID(); static std::string NodeRoleToString(const NodeRole &role); + static NodeRole StringToNodeRole(const std::string &roleStr); + static std::string BoolToString(bool alive); + static bool StringToBool(const std::string &alive); static bool ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id, const int32_t &total_worker_num, const int32_t &total_server_num); static bool Retry(const std::function &func, size_t max_attempts, size_t interval_milliseconds); @@ -112,19 +117,21 @@ class CommUtil { static std::string ClusterStateToString(const ClusterState &state); // Parse the configuration file according to the key. - static std::string ParseConfig(const Configuration &config, const std::string &data); + static std::string ParseConfig(const Configuration &config, const std::string &key); // verify valid of certificate time static bool VerifyCertTime(const X509 *cert, int64_t time = 0); + static bool verifyCertTimeStamp(const X509 *cert); // verify valid of equip certificate with CRL static bool VerifyCRL(const X509 *cert, const std::string &crl_path); - // Check the common name of the certificate - static bool VerifyCommonName(const X509 *cert, const std::string &ca_path); - // The string is divided according to delim + static bool VerifyCommonName(const X509 *caCert, const X509 *subCert); static std::vector Split(const std::string &s, char delim); - // Check the cipher list of the certificate static bool VerifyCipherList(const std::vector &list); - static void InitOpenSSLEnv(); + static bool verifyCertKeyID(const X509 *caCert, const X509 *subCert); + static bool verifySingature(const X509 *caCert, const X509 *subCert); + static bool verifyExtendedAttributes(const X509 *caCert); + static void verifyCertPipeline(const X509 *caCert, const X509 *subCert); + static bool checkCRLTime(const std::string &crlPath); private: static std::random_device rd; diff --git a/mindspore/ccsrc/ps/core/communicator/communicator_base.cc b/mindspore/ccsrc/ps/core/communicator/communicator_base.cc index f8d0feccf0b..5b0eedd8972 100644 --- a/mindspore/ccsrc/ps/core/communicator/communicator_base.cc +++ b/mindspore/ccsrc/ps/core/communicator/communicator_base.cc @@ -43,7 +43,7 @@ void CommunicatorBase::Join() { return; } -bool CommunicatorBase::running() const { return running_; } +bool CommunicatorBase::running() { return running_; } } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/communicator/communicator_base.h b/mindspore/ccsrc/ps/core/communicator/communicator_base.h index 60b80c208f9..a18b92251e9 100644 --- a/mindspore/ccsrc/ps/core/communicator/communicator_base.h +++ b/mindspore/ccsrc/ps/core/communicator/communicator_base.h @@ -19,10 +19,10 @@ #include #include +#include #include #include -#include "utils/hash_map.h" #include "ps/core/communicator/message_handler.h" #include "utils/log_adapter.h" #include "ps/core/communicator/http_message_handler.h" @@ -57,6 +57,7 @@ enum class TcpUserCommand { kQueryInstance, kEnableFLS, kDisableFLS, + kSyncAfterRecover, kExchangeKeys, kGetKeys }; @@ -84,10 +85,10 @@ class CommunicatorBase { bool SendResponse(const void *rsp_data, size_t rsp_len, const std::shared_ptr &msg_handler); - bool running() const; + bool running(); protected: - mindspore::HashMap msg_callbacks_; + std::unordered_map msg_callbacks_; std::thread running_thread_; bool running_; }; diff --git a/mindspore/ccsrc/ps/core/communicator/http_communicator.cc b/mindspore/ccsrc/ps/core/communicator/http_communicator.cc index a8f376e97a5..b5770b811a1 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_communicator.cc +++ b/mindspore/ccsrc/ps/core/communicator/http_communicator.cc @@ -22,11 +22,11 @@ namespace mindspore { namespace ps { namespace core { bool HttpCommunicator::Start() { + MS_EXCEPTION_IF_NULL(http_server_); MS_LOG(INFO) << "Initialize http server IP:" << ip_ << ", PORT:" << port_; if (!http_server_->InitServer()) { MS_LOG(EXCEPTION) << "The communicator init http server failed."; } - MS_EXCEPTION_IF_NULL(http_server_); if (!http_server_->Start()) { MS_LOG(EXCEPTION) << "Http server starting failed."; } @@ -51,9 +51,11 @@ bool HttpCommunicator::Stop() { } void HttpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) { + MS_LOG(INFO) << "msg_type is: " << msg_type; msg_callbacks_[msg_type] = cb; http_msg_callbacks_[msg_type] = std::bind( [&](std::shared_ptr http_msg) -> void { + MS_EXCEPTION_IF_NULL(http_msg); std::shared_ptr http_msg_handler = std::make_shared(http_msg); MS_EXCEPTION_IF_NULL(http_msg_handler); msg_callbacks_[msg_type](http_msg_handler); @@ -61,7 +63,8 @@ void HttpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const Me }, std::placeholders::_1); - std::string url = "/"; + std::string url = ps::PSContext::instance()->http_url_prefix(); + url += "/"; url += msg_type; MS_EXCEPTION_IF_NULL(http_server_); bool is_succeed = http_server_->RegisterRoute(url, &http_msg_callbacks_[msg_type]); diff --git a/mindspore/ccsrc/ps/core/communicator/http_communicator.h b/mindspore/ccsrc/ps/core/communicator/http_communicator.h index 064f123da4c..c5385cf63ca 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_communicator.h +++ b/mindspore/ccsrc/ps/core/communicator/http_communicator.h @@ -19,7 +19,7 @@ #include #include -#include "utils/hash_map.h" +#include #include "ps/core/communicator/http_server.h" #include "ps/core/communicator/http_message_handler.h" #include "ps/core/communicator/task_executor.h" @@ -46,7 +46,7 @@ class HttpCommunicator : public CommunicatorBase { private: std::shared_ptr task_executor_; std::shared_ptr http_server_; - mindspore::HashMap http_msg_callbacks_; + std::unordered_map http_msg_callbacks_; std::string ip_; uint16_t port_; diff --git a/mindspore/ccsrc/ps/core/communicator/http_message_handler.cc b/mindspore/ccsrc/ps/core/communicator/http_message_handler.cc index 0581ca0c371..f8f7e42af4c 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_message_handler.cc +++ b/mindspore/ccsrc/ps/core/communicator/http_message_handler.cc @@ -265,7 +265,7 @@ void HttpMessageHandler::SimpleResponse(int code, const HttpHeaders &headers, co } void HttpMessageHandler::ErrorResponse(int code, const RequestProcessResult &result) { - nlohmann::json error_json = {{"error_message", result.StatusMessage()}}; + nlohmann::json error_json = {{"error_message", result.StatusMessage()}, {"code", kErrorCode}}; std::string out_error = error_json.dump(); AddRespString(out_error); SetRespCode(code); diff --git a/mindspore/ccsrc/ps/core/communicator/http_request_handler.cc b/mindspore/ccsrc/ps/core/communicator/http_request_handler.cc index 3bb2ef2a08e..e28c84f65b2 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_request_handler.cc +++ b/mindspore/ccsrc/ps/core/communicator/http_request_handler.cc @@ -26,7 +26,7 @@ HttpRequestHandler::~HttpRequestHandler() { } } -bool HttpRequestHandler::Initialize(int fd, const mindspore::HashMap &handlers) { +bool HttpRequestHandler::Initialize(int fd, const std::unordered_map &handlers) { evbase_ = event_base_new(); MS_EXCEPTION_IF_NULL(evbase_); struct evhttp *http = evhttp_new(evbase_); @@ -115,8 +115,7 @@ bufferevent *HttpRequestHandler::BuffereventCallback(event_base *base, void *arg SSL_CTX *ctx = reinterpret_cast(arg); SSL *ssl = SSL_new(ctx); MS_EXCEPTION_IF_NULL(ssl); - bufferevent *bev = - bufferevent_openssl_socket_new(base, -1, ssl, BUFFEREVENT_SSL_ACCEPTING, static_cast(BEV_OPT_CLOSE_ON_FREE)); + bufferevent *bev = bufferevent_openssl_socket_new(base, -1, ssl, BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE); MS_EXCEPTION_IF_NULL(bev); return bev; } diff --git a/mindspore/ccsrc/ps/core/communicator/http_request_handler.h b/mindspore/ccsrc/ps/core/communicator/http_request_handler.h index f905e0e62f4..43048525fd9 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_request_handler.h +++ b/mindspore/ccsrc/ps/core/communicator/http_request_handler.h @@ -25,8 +25,8 @@ #include #include +#include -#include "utils/hash_map.h" #include "utils/log_adapter.h" #include "ps/core/communicator/http_message_handler.h" #include "ps/core/communicator/ssl_http.h" @@ -46,7 +46,7 @@ class HttpRequestHandler { HttpRequestHandler() : evbase_(nullptr) {} virtual ~HttpRequestHandler(); - bool Initialize(int fd, const mindspore::HashMap &handlers); + bool Initialize(int fd, const std::unordered_map &handlers); void Run(); bool Stop(); static bufferevent *BuffereventCallback(event_base *base, void *arg); diff --git a/mindspore/ccsrc/ps/core/communicator/http_server.cc b/mindspore/ccsrc/ps/core/communicator/http_server.cc index e01ad5cef4b..657fd52f4b5 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_server.cc +++ b/mindspore/ccsrc/ps/core/communicator/http_server.cc @@ -68,7 +68,7 @@ bool HttpServer::InitServer() { return false; } - fd_ = ::socket(static_cast(AF_INET), static_cast(SOCK_STREAM), 0); + fd_ = ::socket(AF_INET, SOCK_STREAM, 0); if (fd_ < 0) { MS_LOG(ERROR) << "Socker error!"; return false; @@ -84,7 +84,8 @@ bool HttpServer::InitServer() { } struct sockaddr_in addr; - if (memset_s(&addr, sizeof(addr), 0, sizeof(addr)) != EOK) { + errno_t ret = memset_s(&addr, sizeof(addr), 0, sizeof(addr)); + if (ret != EOK) { MS_LOG(EXCEPTION) << "Memset failed."; } @@ -132,6 +133,7 @@ bool HttpServer::RegisterRoute(const std::string &url, OnRequestReceive *functio if (!function) { return false; } + MS_LOG(INFO) << "request handler url is: " << url; request_handlers_[url] = function; return true; } diff --git a/mindspore/ccsrc/ps/core/communicator/http_server.h b/mindspore/ccsrc/ps/core/communicator/http_server.h index e161ac4e6eb..096dc83adca 100644 --- a/mindspore/ccsrc/ps/core/communicator/http_server.h +++ b/mindspore/ccsrc/ps/core/communicator/http_server.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,8 @@ #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_ #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_HTTP_SERVER_H_ +#include "ps/core/communicator/http_message_handler.h" + #include #include #include @@ -34,10 +36,9 @@ #include #include #include +#include #include -#include "utils/hash_map.h" -#include "ps/core/communicator/http_message_handler.h" #include "ps/core/communicator/http_request_handler.h" namespace mindspore { @@ -77,7 +78,7 @@ class HttpServer { std::vector> worker_threads_; std::vector> http_request_handlers; int32_t backlog_; - mindspore::HashMap request_handlers_; + std::unordered_map request_handlers_; int fd_; }; } // namespace core diff --git a/mindspore/ccsrc/ps/core/communicator/ssl_client.cc b/mindspore/ccsrc/ps/core/communicator/ssl_client.cc index d91cb67ab84..6a6198968f4 100644 --- a/mindspore/ccsrc/ps/core/communicator/ssl_client.cc +++ b/mindspore/ccsrc/ps/core/communicator/ssl_client.cc @@ -1,3 +1,4 @@ + /** * Copyright 2021 Huawei Technologies Co., Ltd * @@ -37,7 +38,18 @@ SSLClient::SSLClient() : ssl_ctx_(nullptr), check_time_thread_(nullptr), running SSLClient::~SSLClient() { CleanSSL(); } void SSLClient::InitSSL() { - CommUtil::InitOpenSSLEnv(); + if (!SSL_library_init()) { + MS_LOG(EXCEPTION) << "SSL_library_init failed."; + } + if (!ERR_load_crypto_strings()) { + MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed."; + } + if (!SSL_load_error_strings()) { + MS_LOG(EXCEPTION) << "SSL_load_error_strings failed."; + } + if (!OpenSSL_add_all_algorithms()) { + MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed."; + } ssl_ctx_ = SSL_CTX_new(SSLv23_client_method()); if (!ssl_ctx_) { MS_LOG(EXCEPTION) << "SSL_CTX_new failed"; @@ -65,6 +77,7 @@ void SSLClient::InitSSL() { EVP_PKEY *pkey = nullptr; X509 *cert = nullptr; STACK_OF(X509) *ca_stack = nullptr; + MS_LOG(INFO) << "cliet cert: " << client_cert; BIO *bio = BIO_new_file(client_cert.c_str(), "rb"); MS_EXCEPTION_IF_NULL(bio); PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr); @@ -83,27 +96,26 @@ void SSLClient::InitSSL() { } // 3. load ca cert. - std::string client_ca = kCAcrt; std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath); if (!CommUtil::IsFileExists(ca_path)) { MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist."; } - client_ca = ca_path; - + BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r"); + MS_EXCEPTION_IF_NULL(ca_bio); + X509 *caCert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr); std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath); if (crl_path.empty()) { MS_LOG(INFO) << "The crl path is empty."; + } else if (!CommUtil::checkCRLTime(crl_path)) { + MS_LOG(EXCEPTION) << "check crl time failed"; } else if (!CommUtil::VerifyCRL(cert, crl_path)) { MS_LOG(EXCEPTION) << "Verify crl failed."; } - if (!CommUtil::VerifyCommonName(cert, client_ca)) { - MS_LOG(EXCEPTION) << "Verify common name failed."; - } + CommUtil::verifyCertPipeline(caCert, cert); SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, 0); - - if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) { + if (!SSL_CTX_load_verify_locations(ssl_ctx_, ca_path.c_str(), nullptr)) { MS_LOG(EXCEPTION) << "SSL load ca location failed!"; } @@ -115,13 +127,21 @@ void SSLClient::InitSSL() { if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) { MS_LOG(EXCEPTION) << "SSL use set cipher list failed!"; } + InitSSLCtx(cert, pkey); + StartCheckCertTime(*config_, cert); - // 4. load client cert - if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) { + EVP_PKEY_free(pkey); + X509_free(caCert); + X509_free(cert); + (void)BIO_free(ca_bio); +} + +void SSLClient::InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey) { + if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast(cert))) { MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; } - if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) { + if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast(pkey))) { MS_LOG(EXCEPTION) << "SSL use private key file failed!"; } @@ -138,7 +158,7 @@ void SSLClient::InitSSL() { MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!"; } - StartCheckCertTime(*config_, cert); + SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel); } void SSLClient::CleanSSL() { diff --git a/mindspore/ccsrc/ps/core/communicator/ssl_client.h b/mindspore/ccsrc/ps/core/communicator/ssl_client.h index c3da7320380..a698d8613bd 100644 --- a/mindspore/ccsrc/ps/core/communicator/ssl_client.h +++ b/mindspore/ccsrc/ps/core/communicator/ssl_client.h @@ -37,6 +37,7 @@ #include "ps/core/comm_util.h" #include "ps/constants.h" #include "ps/core/file_configuration.h" +#include "ps/ps_context.h" namespace mindspore { namespace ps { @@ -60,6 +61,7 @@ class SSLClient { void StartCheckCertTime(const Configuration &config, const X509 *cert); void StopCheckCertTime(); + void InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey); SSL_CTX *ssl_ctx_; std::unique_ptr check_time_thread_; diff --git a/mindspore/ccsrc/ps/core/communicator/ssl_http.cc b/mindspore/ccsrc/ps/core/communicator/ssl_http.cc index a51ac2a1d9c..da9b23de808 100644 --- a/mindspore/ccsrc/ps/core/communicator/ssl_http.cc +++ b/mindspore/ccsrc/ps/core/communicator/ssl_http.cc @@ -1,3 +1,4 @@ + /** * Copyright 2021 Huawei Technologies Co., Ltd * @@ -35,7 +36,18 @@ SSLHTTP::SSLHTTP() : ssl_ctx_(nullptr) { InitSSL(); } SSLHTTP::~SSLHTTP() { CleanSSL(); } void SSLHTTP::InitSSL() { - CommUtil::InitOpenSSLEnv(); + if (!SSL_library_init()) { + MS_LOG(EXCEPTION) << "SSL_library_init failed."; + } + if (!ERR_load_crypto_strings()) { + MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed."; + } + if (!SSL_load_error_strings()) { + MS_LOG(EXCEPTION) << "SSL_load_error_strings failed."; + } + if (!OpenSSL_add_all_algorithms()) { + MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed."; + } ssl_ctx_ = SSL_CTX_new(SSLv23_server_method()); if (!ssl_ctx_) { MS_LOG(EXCEPTION) << "SSL_CTX_new failed"; @@ -79,25 +91,39 @@ void SSLHTTP::InitSSL() { MS_LOG(EXCEPTION) << "PKCS12_parse failed."; } PKCS12_free(p12); + if (cert == nullptr) { + MS_LOG(EXCEPTION) << "the cert is nullptr"; + } + if (pkey == nullptr) { + MS_LOG(EXCEPTION) << "the key is nullptr"; + } + if (!CommUtil::verifyCertTimeStamp(cert)) { + MS_LOG(EXCEPTION) << "Verify Cert Time failed."; + } std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList); + InitSSLCtx(cert, pkey, default_cipher_list); + EVP_PKEY_free(pkey); + X509_free(cert); +} + +void SSLHTTP::InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey, const std::string &default_cipher_list) { if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) { MS_LOG(EXCEPTION) << "SSL use set cipher list failed!"; } - - if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) { + if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast(cert))) { MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; } - if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) { + if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast(pkey))) { MS_LOG(EXCEPTION) << "SSL use private key file failed!"; } if (!SSL_CTX_check_private_key(ssl_ctx_)) { MS_LOG(EXCEPTION) << "SSL check private key file failed!"; } - if (!SSL_CTX_set_options(ssl_ctx_, SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) { MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed."; } + SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel); } void SSLHTTP::CleanSSL() { diff --git a/mindspore/ccsrc/ps/core/communicator/ssl_http.h b/mindspore/ccsrc/ps/core/communicator/ssl_http.h index 56c323bd818..60701c0ea96 100644 --- a/mindspore/ccsrc/ps/core/communicator/ssl_http.h +++ b/mindspore/ccsrc/ps/core/communicator/ssl_http.h @@ -53,7 +53,7 @@ class SSLHTTP { void InitSSL(); void CleanSSL(); - + void InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey, const std::string &default_cipher_list); SSL_CTX *ssl_ctx_; }; } // namespace core diff --git a/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.cc b/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.cc index b0d3fe1b4e4..cbb89fccacf 100644 --- a/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.cc +++ b/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.cc @@ -1,3 +1,4 @@ + /** * Copyright 2021 Huawei Technologies Co., Ltd * @@ -43,7 +44,18 @@ SSLWrapper::SSLWrapper() SSLWrapper::~SSLWrapper() { CleanSSL(); } void SSLWrapper::InitSSL() { - CommUtil::InitOpenSSLEnv(); + if (!SSL_library_init()) { + MS_LOG(EXCEPTION) << "SSL_library_init failed."; + } + if (!ERR_load_crypto_strings()) { + MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed."; + } + if (!SSL_load_error_strings()) { + MS_LOG(EXCEPTION) << "SSL_load_error_strings failed."; + } + if (!OpenSSL_add_all_algorithms()) { + MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed."; + } ssl_ctx_ = SSL_CTX_new(SSLv23_server_method()); if (!ssl_ctx_) { MS_LOG(EXCEPTION) << "SSL_CTX_new failed"; @@ -100,31 +112,42 @@ void SSLWrapper::InitSSL() { std::string crl_path = CommUtil::ParseConfig(*(config_), kCrlPath); if (crl_path.empty()) { MS_LOG(INFO) << "The crl path is empty."; + } else if (!CommUtil::checkCRLTime(crl_path)) { + MS_LOG(EXCEPTION) << "check crl time failed"; } else if (!CommUtil::VerifyCRL(cert, crl_path)) { MS_LOG(EXCEPTION) << "Verify crl failed."; } - std::string client_ca = kCAcrt; std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath); if (!CommUtil::IsFileExists(ca_path)) { MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist."; } - client_ca = ca_path; + BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r"); + MS_EXCEPTION_IF_NULL(ca_bio); + X509 *caCert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr); - if (!CommUtil::VerifyCommonName(cert, client_ca)) { - MS_LOG(EXCEPTION) << "Verify common name failed."; - } + CommUtil::verifyCertPipeline(caCert, cert); SSL_CTX_set_verify(ssl_ctx_, SSL_VERIFY_PEER, 0); - if (!SSL_CTX_load_verify_locations(ssl_ctx_, client_ca.c_str(), nullptr)) { + if (!SSL_CTX_load_verify_locations(ssl_ctx_, ca_path.c_str(), nullptr)) { MS_LOG(EXCEPTION) << "SSL load ca location failed!"; } - if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) { + InitSSLCtx(cert, pkey); + StartCheckCertTime(*config_, cert, ca_path); + + EVP_PKEY_free(pkey); + X509_free(caCert); + X509_free(cert); + (void)BIO_free(ca_bio); +} + +void SSLWrapper::InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey) { + if (!SSL_CTX_use_certificate(ssl_ctx_, const_cast(cert))) { MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; } - if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) { + if (!SSL_CTX_use_PrivateKey(ssl_ctx_, const_cast(pkey))) { MS_LOG(EXCEPTION) << "SSL use private key file failed!"; } @@ -135,12 +158,10 @@ void SSLWrapper::InitSSL() { SSL_OP_NO_TLSv1 | SSL_OP_NO_TLSv1_1)) { MS_LOG(EXCEPTION) << "SSL_CTX_set_options failed."; } - if (!SSL_CTX_set_mode(ssl_ctx_, SSL_MODE_AUTO_RETRY)) { MS_LOG(EXCEPTION) << "SSL set mode auto retry failed!"; } - - StartCheckCertTime(*config_, cert, client_ca); + SSL_CTX_set_security_level(ssl_ctx_, kSecurityLevel); } void SSLWrapper::CleanSSL() { diff --git a/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h b/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h index bacefc28eae..18fad291ba3 100644 --- a/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h +++ b/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h @@ -60,6 +60,7 @@ class SSLWrapper { time_t ConvertAsn1Time(const ASN1_TIME *const time) const; void StartCheckCertTime(const Configuration &config, const X509 *cert, const std::string &ca_path); void StopCheckCertTime(); + void InitSSLCtx(const X509 *cert, const EVP_PKEY *pkey); SSL_CTX *ssl_ctx_; diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_client.cc b/mindspore/ccsrc/ps/core/communicator/tcp_client.cc index b1dc160258d..06da3dbfdfd 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/communicator/tcp_client.cc @@ -94,7 +94,6 @@ void TcpClient::Init() { if (event_base_ == nullptr) { event_base_ = event_base_new(); MS_EXCEPTION_IF_NULL(event_base_); - is_stop_ = false; } sockaddr_in sin{}; @@ -160,7 +159,7 @@ void TcpClient::Stop() { void TcpClient::SetTcpNoDelay(const evutil_socket_t &fd) { const int one = 1; - int ret = setsockopt(fd, static_cast(IPPROTO_TCP), static_cast(TCP_NODELAY), &one, sizeof(int)); + int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int)); if (ret < 0) { MS_LOG(EXCEPTION) << "Set socket no delay failed!"; } @@ -178,10 +177,10 @@ void TcpClient::ReadCallback(struct bufferevent *bev, void *ctx) { auto tcp_client = reinterpret_cast(ctx); char read_buffer[kMessageChunkLength]; - size_t read = 0; + int read = 0; - while ((read = bufferevent_read(bev, &read_buffer, sizeof(read_buffer))) > 0) { - tcp_client->OnReadHandler(read_buffer, read); + while ((read = bufferevent_read(bev, &read_buffer, SizeToInt(sizeof(read_buffer)))) > 0) { + tcp_client->OnReadHandler(read_buffer, IntToSize(read)); } } @@ -252,6 +251,8 @@ void TcpClient::Start() { event_base_mutex_.unlock(); MS_EXCEPTION_IF_NULL(event_base_); int ret = event_base_dispatch(event_base_); + // is_started_ should be false when finish dispatch + is_started_ = false; MSLOG_IF(INFO, ret == 0, NoExceptionType) << "Event base dispatch success!"; MSLOG_IF(mindspore::ERROR, ret == 1, NoExceptionType) << "Event base dispatch failed with no events pending or active!"; diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc index 7cb534191e3..03306823c62 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.cc @@ -90,6 +90,7 @@ bool TcpCommunicator::Stop() { } void TcpCommunicator::RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) { + MS_LOG(INFO) << "msg_type is: " << msg_type; msg_callbacks_.try_emplace(msg_type, cb); return; } diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h index d5d4a07e772..df8a2f4de5e 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_communicator.h @@ -21,7 +21,7 @@ #include #include #include -#include "utils/hash_map.h" +#include #include "proto/ps.pb.h" #include "ps/core/server_node.h" #include "ps/core/cluster_metadata.h" @@ -36,7 +36,7 @@ namespace mindspore { namespace ps { namespace core { -const mindspore::HashMap kUserCommandToMsgType = { +const std::unordered_map kUserCommandToMsgType = { {TcpUserCommand::kPush, "push"}, {TcpUserCommand::kPull, "pull"}, {TcpUserCommand::kCount, "count"}, @@ -61,7 +61,8 @@ const mindspore::HashMap kUserCommandToMsgType = { {TcpUserCommand::kNewInstance, "newInstance"}, {TcpUserCommand::kQueryInstance, "queryInstance"}, {TcpUserCommand::kEnableFLS, "enableFLS"}, - {TcpUserCommand::kDisableFLS, "disableFLS"}}; + {TcpUserCommand::kDisableFLS, "disableFLS"}, + {TcpUserCommand::kSyncAfterRecover, "syncAfterRecover"}}; class TcpCommunicator : public CommunicatorBase { public: @@ -92,8 +93,9 @@ class TcpCommunicator : public CommunicatorBase { MS_ERROR_IF_NULL_W_RET_VAL(msg, false); size_t dest_size = msg_str.size(); size_t src_size = msg_str.size(); - if (memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size) != EOK) { - MS_LOG(EXCEPTION) << "Memcpy_s error"; + auto ret = memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "memcpy_s error, error no " << ret; } if (output != nullptr) { diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_msg_handler.cc b/mindspore/ccsrc/ps/core/communicator/tcp_msg_handler.cc index b54ffe6bf25..91320bb1568 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_msg_handler.cc +++ b/mindspore/ccsrc/ps/core/communicator/tcp_msg_handler.cc @@ -20,8 +20,8 @@ namespace mindspore { namespace ps { namespace core { -TcpMsgHandler::TcpMsgHandler(AbstractNode *const abstract_node, const std::shared_ptr &conn, - const std::shared_ptr &meta, const DataPtr &data, size_t size) +TcpMsgHandler::TcpMsgHandler(AbstractNode *abstract_node, const std::shared_ptr &conn, + const std::shared_ptr &meta, DataPtr data, size_t size) : abstract_node_(abstract_node), tcp_conn_(conn), meta_(meta), data_ptr_(data), data_(nullptr), len_(size) { if (data_ptr_ != nullptr) { data_ = data_ptr_.get(); diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_msg_handler.h b/mindspore/ccsrc/ps/core/communicator/tcp_msg_handler.h index 61082f7242a..7d5c577c7fa 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_msg_handler.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_msg_handler.h @@ -28,8 +28,8 @@ namespace ps { namespace core { class TcpMsgHandler : public MessageHandler { public: - TcpMsgHandler(AbstractNode *const abstract_node, const std::shared_ptr &conn, - const std::shared_ptr &meta, const DataPtr &data, size_t size); + TcpMsgHandler(AbstractNode *abstract_node, const std::shared_ptr &conn, + const std::shared_ptr &meta, DataPtr data, size_t size); ~TcpMsgHandler() override = default; void *data() const override; diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_server.cc b/mindspore/ccsrc/ps/core/communicator/tcp_server.cc index 128108b0e0e..f712dc99c4d 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/communicator/tcp_server.cc @@ -175,8 +175,9 @@ void TcpServer::Init() { listener_ = evconnlistener_new_bind(base_, ListenerCallback, reinterpret_cast(this), LEV_OPT_REUSEABLE | LEV_OPT_CLOSE_ON_FREE, -1, reinterpret_cast(&sin), sizeof(sin)); - - MS_EXCEPTION_IF_NULL(listener_); + if (listener_ == nullptr) { + MS_LOG(EXCEPTION) << "bind ip & port failed. please check."; + } if (server_port_ == 0) { struct sockaddr_in sin_bound {}; @@ -306,6 +307,7 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st }); bufferevent_setcb(bev, TcpServer::ReadCallback, nullptr, TcpServer::EventCallback, reinterpret_cast(conn.get())); + MS_LOG(INFO) << "A client is connected, fd is " << fd; if (bufferevent_enable(bev, EV_READ | EV_WRITE) == -1) { MS_LOG(EXCEPTION) << "Buffer event enable read and write failed!"; } @@ -411,7 +413,7 @@ void TcpServer::TimerOnceCallback(evutil_socket_t, int16_t, void *arg) { void TcpServer::SetTcpNoDelay(const evutil_socket_t &fd) { const int one = 1; - int ret = setsockopt(fd, static_cast(IPPROTO_TCP), static_cast(TCP_NODELAY), &one, sizeof(int)); + int ret = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(int)); if (ret < 0) { MS_LOG(EXCEPTION) << "Set socket no delay failed!"; } diff --git a/mindspore/ccsrc/ps/core/configuration.h b/mindspore/ccsrc/ps/core/configuration.h index 4a92de9847f..069eeff0d83 100644 --- a/mindspore/ccsrc/ps/core/configuration.h +++ b/mindspore/ccsrc/ps/core/configuration.h @@ -25,9 +25,10 @@ #include #include #include +#include -#include "utils/hash_map.h" #include "ps/constants.h" +#include "nlohmann/json.hpp" #include "utils/log_adapter.h" namespace mindspore { @@ -51,6 +52,9 @@ class Configuration { // Get configuration data from database or config file. virtual std::string GetString(const std::string &key, const std::string &defaultvalue) const = 0; + // Get configuration vector data from database or config file. + virtual std::vector GetVector(const std::string &key) const = 0; + // Get configuration data from database or config file. virtual int64_t GetInt(const std::string &key, int64_t default_value) const = 0; @@ -59,6 +63,12 @@ class Configuration { // Determine whether the configuration item exists. virtual bool Exists(const std::string &key) const = 0; + + // storage meta data + virtual void PersistFile(const core::ClusterConfig &clusterConfig) const = 0; + + // storage meta data without nodes + virtual void PersistNodes(const core::ClusterConfig &clusterConfig) const = 0; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/file_configuration.cc b/mindspore/ccsrc/ps/core/file_configuration.cc index 9c2be9eded7..edf27a209b8 100644 --- a/mindspore/ccsrc/ps/core/file_configuration.cc +++ b/mindspore/ccsrc/ps/core/file_configuration.cc @@ -50,6 +50,15 @@ std::string FileConfiguration::Get(const std::string &key, const std::string &de return res; } +std::vector FileConfiguration::GetVector(const std::string &key) const { + if (!js.contains(key)) { + MS_LOG(WARNING) << "The key:" << key << " is not exist."; + return std::vector(); + } + + return js.at(key); +} + std::string FileConfiguration::GetString(const std::string &key, const std::string &defaultvalue) const { if (!js.contains(key)) { MS_LOG(WARNING) << "The key:" << key << " is not exist."; @@ -68,13 +77,7 @@ int64_t FileConfiguration::GetInt(const std::string &key, int64_t default_value) return res; } -void FileConfiguration::Put(const std::string &key, const std::string &value) { - std::ofstream output_file(file_path_); - js[key] = value; - output_file << js.dump(); - - output_file.close(); -} +void FileConfiguration::Put(const std::string &key, const std::string &value) { js[key] = value; } bool FileConfiguration::Exists(const std::string &key) const { if (!js.contains(key)) { @@ -82,6 +85,53 @@ bool FileConfiguration::Exists(const std::string &key) const { } return true; } + +void FileConfiguration::PersistNodes(const core::ClusterConfig &clusterConfig) const { + if (!CommUtil::IsFileExists(file_path_)) { + MS_LOG(WARNING) << "The file path:" << file_path_ << " is not exist. create one"; + } + + nlohmann::json persist_js; + persist_js[kRecoveryTotalNodeNum] = clusterConfig.initial_total_node_num; + persist_js[kRecoveryNextWorkerRankId] = clusterConfig.initial_next_worker_rank_id; + persist_js[kRecoveryNextServerRankId] = clusterConfig.initial_next_server_rank_id; + + auto node_infos = clusterConfig.initial_registered_nodes_infos; + for (const auto kvs : node_infos) { + std::unordered_map res; + res["ip"] = kvs.second.ip_; + res["port"] = std::to_string(kvs.second.port_); + res["node_id"] = kvs.second.node_id_; + res["rank_id"] = std::to_string(kvs.second.rank_id_); + res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_); + res["alive"] = CommUtil::BoolToString(kvs.second.is_alive); + persist_js["node_ids"].push_back(res); + } + + std::ofstream output_file(file_path_); + output_file << persist_js.dump(); + + output_file.close(); + MS_LOG(INFO) << "The nodes meta data persist to " << file_path_; +} + +void FileConfiguration::PersistFile(const core::ClusterConfig &clusterConfig) const { + if (!CommUtil::IsFileExists(file_path_)) { + MS_LOG(WARNING) << "The file path:" << file_path_ << " is not exist. create one"; + } + + nlohmann::json persist_js; + persist_js[kRecoveryWorkerNum] = clusterConfig.initial_worker_num; + persist_js[kRecoveryServerNum] = clusterConfig.initial_server_num; + persist_js[kRecoverySchedulerIp] = clusterConfig.scheduler_host; + persist_js[kRecoverySchedulerPort] = clusterConfig.scheduler_port; + + std::ofstream output_file(file_path_); + output_file << persist_js.dump(); + + output_file.close(); + MS_LOG(INFO) << "The meta data persist to " << file_path_; +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/file_configuration.h b/mindspore/ccsrc/ps/core/file_configuration.h index 223f7ecceaa..950c86386db 100644 --- a/mindspore/ccsrc/ps/core/file_configuration.h +++ b/mindspore/ccsrc/ps/core/file_configuration.h @@ -25,12 +25,11 @@ #include #include #include +#include -#include "utils/hash_map.h" #include "ps/constants.h" #include "utils/log_adapter.h" #include "ps/core/comm_util.h" -#include "nlohmann/json.hpp" #include "ps/core/configuration.h" namespace mindspore { @@ -58,12 +57,18 @@ class FileConfiguration : public Configuration { std::string GetString(const std::string &key, const std::string &defaultvalue) const override; + std::vector GetVector(const std::string &key) const override; + int64_t GetInt(const std::string &key, int64_t default_value) const override; void Put(const std::string &key, const std::string &value) override; bool Exists(const std::string &key) const override; + void PersistFile(const core::ClusterConfig &clusterConfig) const override; + + void PersistNodes(const core::ClusterConfig &clusterConfig) const override; + private: // The path of the configuration file. std::string file_path_; diff --git a/mindspore/ccsrc/ps/core/follower_scaler.cc b/mindspore/ccsrc/ps/core/follower_scaler.cc index ac33ab2a835..54bcbd51f2d 100644 --- a/mindspore/ccsrc/ps/core/follower_scaler.cc +++ b/mindspore/ccsrc/ps/core/follower_scaler.cc @@ -132,6 +132,7 @@ void FollowerScaler::ProcessBeforeScaleOut() { } scaling_state_ = NodeScaleState::kWaiting; // Notify scheduler that this node is ready for elastic scaling out. + MS_EXCEPTION_IF_NULL(node_); node_->set_ready_for_scale_out(); } @@ -142,6 +143,7 @@ void FollowerScaler::ProcessBeforeScaleIn() { } scaling_state_ = NodeScaleState::kWaiting; // Notify scheduler that this node is ready for elastic scaling in. + MS_EXCEPTION_IF_NULL(node_); node_->set_ready_for_scale_in(); } @@ -153,6 +155,7 @@ void FollowerScaler::ProcessAfterScaleOut() { } scaling_state_ = NodeScaleState::kNormal; // Notify scheduler that scaling out of this node is done. + MS_EXCEPTION_IF_NULL(node_); node_->set_scale_out_done(); } @@ -164,6 +167,7 @@ void FollowerScaler::ProcessAfterScaleIn() { } scaling_state_ = NodeScaleState::kNormal; // Notify scheduler that scaling out of this node is done. + MS_EXCEPTION_IF_NULL(node_); node_->set_scale_in_done(); } diff --git a/mindspore/ccsrc/ps/core/node.cc b/mindspore/ccsrc/ps/core/node.cc index b74e4a26c43..a3b6f3fa9b7 100644 --- a/mindspore/ccsrc/ps/core/node.cc +++ b/mindspore/ccsrc/ps/core/node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ uint16_t Node::BoundPort() const { return node_info_.port_; } std::string Node::BoundIp() const { return node_info_.ip_; } bool Node::WaitForStart(const uint32_t &timeout) { + MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is Waiting for start!"; std::unique_lock lock(wait_start_mutex_); bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] { bool result = this->is_ready_.load(); @@ -79,8 +80,8 @@ bool Node::SendMessageSync(const std::shared_ptr &client, const std:: if (!client->SendMessage(meta, protos, data, size)) { MS_LOG(WARNING) << "Client send message failed."; } - MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) - << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; + MS_LOG(INFO) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_) + << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id; return Wait(request_id, timeout); } @@ -152,15 +153,16 @@ void Node::ProcessSendDataResp(const std::shared_ptr &meta, const P if (size > 0) { size_t dest_size = size; size_t src_size = size; - if (memcpy_s(received_data.get()->data(), dest_size, data, src_size) != EOK) { - MS_LOG(EXCEPTION) << "The memcpy_s error"; + auto ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; } } if (it != receive_messages_.end()) { it->second[rank_id] = received_data; } else { - mindspore::HashMap res; - (void)res.emplace(rank_id, received_data); + std::unordered_map res; + (void)res.insert(std::make_pair(rank_id, received_data)); receive_messages_[request_id] = res; } } else { @@ -169,15 +171,16 @@ void Node::ProcessSendDataResp(const std::shared_ptr &meta, const P if (size > 0) { size_t dest_size = size; size_t src_size = size; - if (memcpy_s(received_data.get()->data(), dest_size, data, src_size) != EOK) { - MS_LOG(EXCEPTION) << "The memcpy_s error"; + auto ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")"; } } if (it != workder_receive_messages_.end()) { it->second[rank_id] = received_data; } else { - mindspore::HashMap res; - (void)res.emplace(rank_id, received_data); + std::unordered_map res; + (void)res.insert(std::make_pair(rank_id, received_data)); workder_receive_messages_[request_id] = res; } } @@ -197,7 +200,7 @@ void Node::RunMessageCallback(const uint64_t &request_id) { } message_callbacks_mutex_.lock(); - message_callbacks_.erase(it); + (void)message_callbacks_.erase(it); } } message_callbacks_mutex_.unlock(); diff --git a/mindspore/ccsrc/ps/core/node.h b/mindspore/ccsrc/ps/core/node.h index d0da5291f7b..d6495d1ac5b 100644 --- a/mindspore/ccsrc/ps/core/node.h +++ b/mindspore/ccsrc/ps/core/node.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,13 +24,13 @@ #include #include #include +#include #include #include #include #include #include -#include "utils/hash_map.h" #include "ps/core/cluster_metadata.h" #include "ps/core/cluster_config.h" #include "ps/ps_context.h" @@ -42,8 +42,8 @@ namespace mindspore { namespace ps { namespace core { -constexpr int kTimeoutInSeconds = 180; -constexpr int kCommTimeoutInSeconds = 180; +constexpr int kTimeoutInSeconds = 30; +constexpr int kCommTimeoutInSeconds = 10; class Node { public: Node() @@ -92,9 +92,7 @@ class Node { void RunMessageCallback(const uint64_t &request_id); NodeInfo node_info_; - // Whether the cluster is ready std::atomic is_ready_; - // Whether the cluster is finished. std::atomic is_finish_; std::atomic is_already_stopped_; @@ -108,7 +106,7 @@ class Node { std::mutex finish_mutex_; // the key is: request_id, the value is: - mindspore::HashMap> message_tracker_; + std::unordered_map> message_tracker_; std::mutex message_tracker_mutex_; std::condition_variable message_tracker_cond_; @@ -128,13 +126,13 @@ class Node { std::mutex client_mutex_; // the key is: request_id - mindspore::HashMap message_callbacks_; + std::unordered_map message_callbacks_; std::mutex message_callbacks_mutex_; // the key is: request_id, the value is: - mindspore::HashMap> receive_messages_; + std::unordered_map> receive_messages_; // the key is: request_id, the value is: - mindspore::HashMap> workder_receive_messages_; + std::unordered_map> workder_receive_messages_; std::map, bool> receive_messages_done_; std::mutex receive_messages_mutex_; }; diff --git a/mindspore/ccsrc/ps/core/node_info.h b/mindspore/ccsrc/ps/core/node_info.h index ceafb120b04..4a37cc48116 100644 --- a/mindspore/ccsrc/ps/core/node_info.h +++ b/mindspore/ccsrc/ps/core/node_info.h @@ -32,7 +32,9 @@ enum class ClusterEvent { READY_FOR_SCALE_OUT = 3, READY_FOR_SCALE_IN = 4, CLUSTER_SCALE_OUT_DONE = 5, - CLUSTER_SCALE_IN_DONE = 6 + CLUSTER_SCALE_IN_DONE = 6, + READY_FOR_RECOVERY = 7, + RECOVERY_DONE = 8, }; struct NodeInfo { diff --git a/mindspore/ccsrc/ps/core/node_manager.cc b/mindspore/ccsrc/ps/core/node_manager.cc index 4783b0b055b..9f34e108688 100644 --- a/mindspore/ccsrc/ps/core/node_manager.cc +++ b/mindspore/ccsrc/ps/core/node_manager.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,15 +25,12 @@ void NodeManager::InitNode() { meta_data_ = std::make_unique(PSContext::instance()->cluster_config().initial_worker_num, PSContext::instance()->cluster_config().initial_server_num); MS_EXCEPTION_IF_NULL(meta_data_); - total_node_num_ = UintToInt(initial_total_node_num_); + total_node_num_ = initial_total_node_num_; } -uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const std::shared_ptr &meta) { - MS_EXCEPTION_IF_NULL(meta); - MS_EXCEPTION_IF_NULL(meta_data_); - std::lock_guard lock(assign_rank_id_mutex_); +uint32_t NodeManager::checkIfRankIdExist(const RegisterMessage ®ister_message, + const std::shared_ptr &meta) { uint32_t rank_id = UINT_MAX; - const std::string &node_id = register_message.node_id(); if (registered_nodes_info_.find(node_id) != registered_nodes_info_.end()) { const std::string &new_ip = register_message.ip(); @@ -42,10 +39,48 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const registered_nodes_info_[node_id].is_alive = true; registered_nodes_info_[node_id].ip_ = new_ip; registered_nodes_info_[node_id].port_ = static_cast(new_port); - MS_LOG(INFO) << "The node id: " << node_id << " is already assigned!"; + MS_LOG(INFO) << "The node id: " << node_id << " is already assigned!" + << ", ip: " << register_message.ip() << ", port: " << register_message.port() + << ", rank id: " << rank_id << ", alive: " << registered_nodes_info_[node_id].is_alive + << ", the node_role:" << CommUtil::NodeRoleToString(registered_nodes_info_[node_id].node_role_); return rank_id; } + // This is for scheduler recovery + core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); + std::unordered_map recovery_node_infos = clusterConfig.initial_registered_nodes_infos; + if (recovery_node_infos.find(node_id) != recovery_node_infos.end()) { + const std::string &new_ip = register_message.ip(); + uint32_t new_port = register_message.port(); + rank_id = recovery_node_infos[node_id].rank_id_; + recovery_node_infos[node_id].is_alive = true; + recovery_node_infos[node_id].ip_ = new_ip; + recovery_node_infos[node_id].port_ = static_cast(new_port); + registered_nodes_info_[node_id] = recovery_node_infos[node_id]; + MS_LOG(INFO) << "The node id: " << node_id << " is recovery successful!" + << ", ip: " << recovery_node_infos[node_id].ip_ << ", port: " << recovery_node_infos[node_id].port_ + << ", rank id: " << rank_id << ", alive: " << recovery_node_infos[node_id].is_alive + << ", the node_role:" << CommUtil::NodeRoleToString(recovery_node_infos[node_id].node_role_); + return rank_id; + } + return rank_id; +} + +uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const std::shared_ptr &meta) { + MS_EXCEPTION_IF_NULL(meta); + MS_EXCEPTION_IF_NULL(meta_data_); + std::lock_guard lock(assign_rank_id_mutex_); + uint32_t rank_id = checkIfRankIdExist(register_message, meta); + if (rank_id != UINT_MAX) { + return rank_id; + } + if (total_node_num_ == SizeToInt(registered_nodes_info_.size())) { + MS_LOG(WARNING) << "There are enough nodes registering to scheduler."; + return UINT_MAX; + } + + const std::string &node_id = register_message.node_id(); + // create new rank id if (register_message.role() == NodeRole::SERVER) { const std::string &ip = register_message.ip(); uint32_t port = register_message.port(); @@ -63,23 +98,24 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const rank_id = meta->rank_id(); MS_LOG(INFO) << "Use the old rank id:" << rank_id; } else { - rank_id = IntToUint(++next_server_rank_id_); + rank_id = ++next_server_rank_id_; } } else { registered_nodes_info_.erase((*rank_it).first); } if (rank_id >= meta_data_->server_num) { - MS_LOG(WARNING) << "The rank id is greater than the number of servers:" << meta_data_->server_num; + MS_LOG(ERROR) << "The rank id is greater than the number of servers:" << meta_data_->server_num; rank_id = UINT_MAX; --next_server_rank_id_; + return rank_id; } NodeInfo node_info; node_info.node_role_ = NodeRole::SERVER; node_info.node_id_ = node_id; node_info.rank_id_ = rank_id; node_info.ip_ = ip; - node_info.port_ = static_cast(port); + node_info.port_ = port; node_info.is_alive = true; registered_nodes_info_[node_id] = node_info; MS_LOG(INFO) << "The server node id:" << node_id << ",node ip: " << node_info.ip_ << ",node port:" << port @@ -102,26 +138,28 @@ uint32_t NodeManager::NextRankId(const RegisterMessage ®ister_message, const rank_id = meta->rank_id(); MS_LOG(INFO) << "Use the old rank id:" << rank_id; } else { - rank_id = IntToUint(++next_worker_rank_id_); + rank_id = ++next_worker_rank_id_; } } else { registered_nodes_info_.erase((*worker_rank_it).first); } if (rank_id >= meta_data_->worker_num) { - MS_LOG(WARNING) << "The rank id is greater than the number of workers:" << meta_data_->worker_num; + MS_LOG(ERROR) << "The rank id is greater than the number of workers:" << meta_data_->worker_num; rank_id = UINT_MAX; --next_worker_rank_id_; + return rank_id; } NodeInfo node_info; node_info.node_role_ = NodeRole::WORKER; node_info.node_id_ = node_id; node_info.rank_id_ = rank_id; node_info.ip_ = ip; - node_info.port_ = static_cast(port); + node_info.port_ = port; node_info.is_alive = true; registered_nodes_info_[node_id] = node_info; - MS_LOG(INFO) << "The worker node id:" << node_id << " assign rank id:" << rank_id; + MS_LOG(INFO) << "The worker node id:" << node_id << ", node ip: " << node_info.ip_ << ", node port:" << port + << " assign rank id:" << rank_id; } return rank_id; } @@ -183,6 +221,21 @@ void NodeManager::UpdateCluster() { (void)heartbeats_.erase(iter->first); finish_nodes_id_.insert(iter->first); } + if (onPersist) { + onPersist(); + } + } else if (SizeToInt(heartbeats_.size()) == total_node_num_) { + if (cluster_state_ == ClusterState::NODE_TIMEOUT) { + for (auto it = registered_nodes_info_.begin(); it != registered_nodes_info_.end(); ++it) { + if (registered_nodes_info_.count(it->first)) { + registered_nodes_info_[it->first].is_alive = true; + } + } + if (onPersist) { + onPersist(); + } + UpdateClusterState(ClusterState::CLUSTER_READY); + } } // 2. update cluster finish state @@ -223,12 +276,16 @@ bool NodeManager::IsAllNodesScaleOutDone() const { bool NodeManager::IsAllNodesScaleInDone() const { return SizeToInt(scale_in_done_nodes_id_.size()) == total_node_num_; } -const mindspore::HashMap &NodeManager::nodes_info() const { return nodes_info_; } +const std::unordered_map &NodeManager::nodes_info() const { return nodes_info_; } -const mindspore::HashMap &NodeManager::registered_nodes_info() const { +const std::unordered_map &NodeManager::registered_nodes_info() const { return registered_nodes_info_; } +void NodeManager::set_registered_nodes_info(const std::unordered_map registered_nodes_info) { + this->registered_nodes_info_ = registered_nodes_info; +} + void NodeManager::UpdateNodesInfo() { MS_LOG(INFO) << "Update nodes info."; nodes_info_.clear(); @@ -271,9 +328,19 @@ void NodeManager::ResetMetadata(const std::vector &scale_in_nodes) MS_LOG(INFO) << "The next server rank id:" << next_server_rank_id_; } registered_nodes_info_.clear(); + ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); + clusterConfig.initial_registered_nodes_infos.clear(); heartbeats_.clear(); } +void NodeManager::SaveRecoveryRankId(const NodeInfo &info) { + if (info.node_role_ == NodeRole::SERVER) { + recovery_server_rank_id_.push_back(info.rank_id_); + } else if (info.node_role_ == NodeRole::WORKER) { + recovery_worker_rank_id_.push_back(info.rank_id_); + } +} + bool NodeManager::IsWorkerOrServer0() { bool res = std::any_of(registered_nodes_info_.begin(), registered_nodes_info_.end(), [](auto item) { if (item.second.node_role_ == NodeRole::WORKER && item.second.is_alive == false) { @@ -308,6 +375,18 @@ void NodeManager::set_server_num(const int32_t &server_num) { meta_data_->server int32_t NodeManager::worker_num() const { return UintToInt(meta_data_->worker_num); } int32_t NodeManager::server_num() const { return UintToInt(meta_data_->server_num); } + +int32_t NodeManager::next_worker_rank_id() const { return next_worker_rank_id_.load(); } + +int32_t NodeManager::next_server_rank_id() const { return next_server_rank_id_.load(); } + +void NodeManager::set_next_worker_rank_id(const int32_t &next_worker_rank_id) { + this->next_worker_rank_id_ = next_worker_rank_id; +} +void NodeManager::set_next_server_rank_id(const int32_t &next_server_rank_id) { + this->next_server_rank_id_ = next_server_rank_id; +} +void NodeManager::setPersistCallback(const OnPersist &onPersist) { this->onPersist = onPersist; } } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/node_manager.h b/mindspore/ccsrc/ps/core/node_manager.h index 5b787f61139..a0658948277 100644 --- a/mindspore/ccsrc/ps/core/node_manager.h +++ b/mindspore/ccsrc/ps/core/node_manager.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,13 +26,13 @@ #include #include #include +#include #include #include +#include #include #include -#include "utils/hash_map.h" -#include "utils/hash_set.h" #include "ps/core/node.h" #include "utils/log_adapter.h" #include "utils/convert_utils_base.h" @@ -53,10 +53,11 @@ class NodeManager { node_state_(NodeState::NODE_STARTING), cluster_state_(ClusterState::ClUSTER_STARTING) {} virtual ~NodeManager() = default; - + using OnPersist = std::function; // When initializing nodes, the initial number of nodes will be assigned to the total number of nodes. void InitNode(); uint32_t NextRankId(const RegisterMessage ®ister_message, const std::shared_ptr &meta); + uint32_t checkIfRankIdExist(const RegisterMessage ®ister_message, const std::shared_ptr &meta); void UpdateHeartbeat(const std::string &node_id); std::vector FetchServersMeta(); @@ -86,8 +87,8 @@ class NodeManager { // nodes and Determine whether the nodes are equal to total_node_num_. bool IsAllNodesScaleInDone() const; - const mindspore::HashMap &nodes_info() const; - const mindspore::HashMap ®istered_nodes_info() const; + const std::unordered_map &nodes_info() const; + const std::unordered_map ®istered_nodes_info() const; // After all the nodes are registered successfully, the nodes info can be updated. void UpdateNodesInfo(); @@ -98,6 +99,9 @@ class NodeManager { int32_t worker_num() const; int32_t server_num() const; + int32_t next_worker_rank_id() const; + int32_t next_server_rank_id() const; + void UpdateNodeState(const NodeState &state); void UpdateClusterState(const ClusterState &state); NodeState GetNodeState(); @@ -107,11 +111,18 @@ class NodeManager { // will re-register. void ResetMetadata(const std::vector &scale_in_nodes = {}); + void SaveRecoveryRankId(const NodeInfo &info); + bool IsWorkerOrServer0(); // Determine whether the node id has been registered. bool IsNodeRegistered(const std::string &node_id); + void set_registered_nodes_info(const std::unordered_map registered_nodes_info); + void set_next_worker_rank_id(const int32_t &next_worker_rank_id); + void set_next_server_rank_id(const int32_t &next_server_rank_id); + void setPersistCallback(const OnPersist &onPersist); + private: std::mutex node_mutex_; std::mutex cluster_mutex_; @@ -124,28 +135,33 @@ class NodeManager { std::atomic next_server_rank_id_; // Whenever a node is registered, it will be stored in this map. - mindspore::HashMap registered_nodes_info_; + std::unordered_map registered_nodes_info_; // When all nodes are registered successfully, then all nodes info will be stored in this map. In other words, the // nodes_info_ is a snapshot of the registered_nodes_info_. - mindspore::HashMap nodes_info_; + std::unordered_map nodes_info_; std::mutex assign_rank_id_mutex_; std::mutex heartbeat_mutex_; - mindspore::HashMap heartbeats_; + std::unordered_map heartbeats_; // timeout nodes - mindspore::HashMap timeout_nodes_info_; - mindspore::HashSet finish_nodes_id_; + std::unordered_map timeout_nodes_info_; + std::unordered_set finish_nodes_id_; // The scheduler aggregates scale_out_done messages from workers/servers - mindspore::HashSet scale_out_done_nodes_id_; + std::unordered_set scale_out_done_nodes_id_; // The scheduler aggregates scale_in_done messages from workers/servers - mindspore::HashSet scale_in_done_nodes_id_; + std::unordered_set scale_in_done_nodes_id_; // Cluster metadata information can be dynamically changed std::unique_ptr meta_data_; NodeState node_state_; ClusterState cluster_state_; + + std::deque recovery_worker_rank_id_; + std::deque recovery_server_rank_id_; + + OnPersist onPersist; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/node_recovery.cc b/mindspore/ccsrc/ps/core/node_recovery.cc index e83bb15ae4c..bc3e311978d 100644 --- a/mindspore/ccsrc/ps/core/node_recovery.cc +++ b/mindspore/ccsrc/ps/core/node_recovery.cc @@ -23,6 +23,7 @@ bool NodeRecovery::Recover() { if (recovery_storage_ == nullptr) { return false; } + // 1. recover worker num MS_ERROR_IF_NULL_W_RET_VAL(node_, false); if (recovery_storage_->Exists(kRecoveryWorkerNum)) { @@ -42,8 +43,7 @@ bool NodeRecovery::Recover() { // 3. recover scheduler ip if (recovery_storage_->Exists(kRecoverySchedulerIp)) { - std::string scheduler_ip = recovery_storage_->Get(kRecoverySchedulerIp, ""); - node_->set_scheduler_ip(scheduler_ip); + node_->set_scheduler_ip(recovery_storage_->GetString(kRecoverySchedulerIp, "")); } else { node_->set_scheduler_ip(PSContext::instance()->cluster_config().scheduler_host); } diff --git a/mindspore/ccsrc/ps/core/protos/comm.proto b/mindspore/ccsrc/ps/core/protos/comm.proto index 5f6ad88d6ec..0304bfc97e7 100644 --- a/mindspore/ccsrc/ps/core/protos/comm.proto +++ b/mindspore/ccsrc/ps/core/protos/comm.proto @@ -39,6 +39,8 @@ enum NodeCommand { SCALE_IN_DONE = 11; // This command is used to send user defined event. SEND_EVENT = 12; + // This command is used to send scheduler recovery event. + SCHEDULER_RECOVERY = 13; } enum NodeRole { @@ -73,6 +75,7 @@ message RegisterMessage { message RegisterRespMessage { string node_id = 1; + uint32 rank_id = 2; } message HeartbeatMessage { @@ -96,6 +99,7 @@ enum ClusterState { CLUSTER_NEW_INSTANCE = 6; CLUSTER_ENABLE_FLS = 7; CLUSTER_DISABLE_FLS = 8; + CLUSTER_SCHEDULER_RECOVERY = 9; } message HeartbeatRespMessage { diff --git a/mindspore/ccsrc/ps/core/protos/fl.proto b/mindspore/ccsrc/ps/core/protos/fl.proto index ec5dde2c9b5..93c46e12661 100644 --- a/mindspore/ccsrc/ps/core/protos/fl.proto +++ b/mindspore/ccsrc/ps/core/protos/fl.proto @@ -139,6 +139,24 @@ message Prime { bytes prime = 1; } +message PairClientListSign { + string fl_id = 1; + bytes signature = 2; +} + +message ClientListSign { + map client_list_sign = 1; +} + +message PairKeyAttestation { + string fl_id = 1; + string certificate = 2; +} + +message KeyAttestation { + map key_attestations = 1; +} + message PBMetadata { oneof value { DeviceMeta device_meta = 1; @@ -159,6 +177,12 @@ message PBMetadata { ClientNoises client_noises = 11; Prime prime = 12; + + ClientListSign client_list_sign = 13; + PairClientListSign pair_client_list_sign = 14; + + KeyAttestation key_attestation = 15; + PairKeyAttestation pair_key_attestation = 16; } } @@ -178,8 +202,8 @@ message SyncIterationResponse { } message PrepareForNextIterRequest { - bool is_last_iter_valid = 1; - string reason = 2; + bool is_last_iter_valid = 2; + string reason = 3; } message PrepareForNextIterResponse { @@ -214,3 +238,7 @@ message EndLastIterRequest { message EndLastIterResponse { string result = 1; } + +message SyncAfterRecover { + uint64 current_iter_num = 1; +} diff --git a/mindspore/ccsrc/ps/core/recovery_base.cc b/mindspore/ccsrc/ps/core/recovery_base.cc index b405a1df768..dca6edcf2d2 100644 --- a/mindspore/ccsrc/ps/core/recovery_base.cc +++ b/mindspore/ccsrc/ps/core/recovery_base.cc @@ -19,18 +19,19 @@ namespace mindspore { namespace ps { namespace core { -void RecoveryBase::Initialize(const std::string &config_json) { +bool RecoveryBase::Initialize(const std::string &config_json) { nlohmann::json recovery_config; try { recovery_config = nlohmann::json::parse(config_json); } catch (nlohmann::json::exception &e) { MS_LOG(ERROR) << "Parse the json:" << config_json; + return false; } MS_LOG(INFO) << "The node is support recovery."; if (!recovery_config.contains(kStoreType)) { MS_LOG(WARNING) << "The " << kStoreType << " is not existed."; - return; + return false; } std::string storage_file_path = ""; std::string type = recovery_config.at(kStoreType).dump(); @@ -39,23 +40,67 @@ void RecoveryBase::Initialize(const std::string &config_json) { if (!recovery_config.contains(kStoreFilePath)) { MS_LOG(WARNING) << "The " << kStoreFilePath << " is not existed."; - return; + return false; } storage_file_path = recovery_config.at(kStoreFilePath); if (storage_file_path == "") { MS_LOG(EXCEPTION) << "If the scheduler support recovery, and if the persistent storage is a file, the path of " "the file must be configured"; } - recovery_storage_ = std::make_unique(storage_file_path); MS_EXCEPTION_IF_NULL(recovery_storage_); - if (!recovery_storage_->Initialize()) { - MS_LOG(INFO) << "The storage file path " << storage_file_path << " is empty."; + MS_LOG(WARNING) << "The storage file path " << storage_file_path << " is empty."; } } MS_LOG(INFO) << "The storage type is:" << storage_type_ << ", the storage file path is:" << storage_file_path; + return true; +} + +bool RecoveryBase::InitializeNodes(const std::string &config_json) { + nlohmann::json recovery_config; + try { + recovery_config = nlohmann::json::parse(config_json); + } catch (nlohmann::json::exception &e) { + MS_LOG(ERROR) << "Parse the json:" << config_json; + return false; + } + + if (!recovery_config.contains(kSchedulerStoreFilePath)) { + MS_LOG(WARNING) << "The " << kStoreFilePath << " is not existed."; + return false; + } + + // this is only for scheduler + std::string scheduler_storage_file_path = recovery_config.at(kSchedulerStoreFilePath); + if (scheduler_storage_file_path == "") { + MS_LOG(WARNING) << "scheduler storage file path is not exist!"; + } + scheduler_recovery_storage_ = std::make_unique(scheduler_storage_file_path); + MS_EXCEPTION_IF_NULL(scheduler_recovery_storage_); + if (!scheduler_recovery_storage_->Initialize()) { + MS_LOG(WARNING) << "The scheduler storage file path " << scheduler_storage_file_path << " is empty."; + } + + MS_LOG(INFO) << "the scheduler storage file path is:" << scheduler_storage_file_path; + return true; +} + +void RecoveryBase::Persist(const core::ClusterConfig &clusterConfig) const { + if (recovery_storage_ == nullptr) { + MS_LOG(WARNING) << "recovery storage is null, so don't persist meta data"; + return; + } + recovery_storage_->PersistFile(clusterConfig); +} + +void RecoveryBase::PersistNodesInfo(const core::ClusterConfig &clusterConfig) const { + if (scheduler_recovery_storage_ == nullptr) { + MS_LOG(WARNING) << "scheduler recovery storage is null, so don't persist nodes meta data"; + return; + } + scheduler_recovery_storage_->PersistNodes(clusterConfig); } } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/recovery_base.h b/mindspore/ccsrc/ps/core/recovery_base.h index 870fc25cd87..f90e8b698b1 100644 --- a/mindspore/ccsrc/ps/core/recovery_base.h +++ b/mindspore/ccsrc/ps/core/recovery_base.h @@ -41,15 +41,27 @@ class RecoveryBase { virtual ~RecoveryBase() = default; // Initialize the recovery configuration item and get the storage type of recovery. - virtual void Initialize(const std::string &json_config); + virtual bool Initialize(const std::string &json_config); + + // Initialize the recovery configuration item and get the storage type of recovery. + virtual bool InitializeNodes(const std::string &json_config); // The node needs to recover metadata information when it starts. virtual bool Recover() = 0; + // Persist metadata to storage. + virtual void Persist(const core::ClusterConfig &clusterConfig) const; + + // Persist metadata to storage. + virtual void PersistNodesInfo(const core::ClusterConfig &clusterConfig) const; + protected: // Persistent storage used to save metadata. std::unique_ptr recovery_storage_; + // Persistent storage used to save server nodes metadata. + std::unique_ptr scheduler_recovery_storage_; + // Storage type for recovery,Currently only supports storage of file types StorageType storage_type_; }; diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index 2e16e0edb83..d6d124d9f07 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,7 @@ */ #include "ps/core/scheduler_node.h" -#include -#include +#include "ps/core/scheduler_recovery.h" namespace mindspore { namespace ps { @@ -30,13 +29,28 @@ SchedulerNode::~SchedulerNode() { bool SchedulerNode::Start(const uint32_t &timeout) { MS_LOG(INFO) << "[Scheduler start]: 1. Begin to start scheduler node!"; + config_ = std::make_unique(PSContext::instance()->config_file_path()); + MS_EXCEPTION_IF_NULL(config_); + if (!config_->Initialize()) { + MS_LOG(INFO) << "The config file is empty, then init node by context."; + InitNodeMetaData(); + } else { + if (!RecoverScheduler()) { + MS_LOG(WARNING) << "Recover the server node is failed."; + } + } + if (PSContext::instance()->scheduler_manage_port() != 0) { - MS_LOG(WARNING) << "Start the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip() + MS_LOG(WARNING) << "Start the restful scheduler http service, the ip is 127.0.0.1 " << ", the port:" << PSContext::instance()->scheduler_manage_port(); StartRestfulServer(kLocalIp, PSContext::instance()->scheduler_manage_port(), 1); } Initialize(); StartUpdateClusterStateTimer(); + RunRecovery(); + if (is_worker_timeout_) { + BroadcastTimeoutEvent(); + } if (!WaitForStart(timeout)) { MS_LOG(ERROR) << "Start Scheduler node timeout!"; return false; @@ -48,6 +62,67 @@ bool SchedulerNode::Start(const uint32_t &timeout) { return true; } +void SchedulerNode::RunRecovery() { + core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); + // create tcp client to myself in case of event dispatch failed when Send reconnect msg to server failed + client_to_scheduler_ = + std::make_shared(clusterConfig.scheduler_host, clusterConfig.scheduler_port, config_.get()); + MS_EXCEPTION_IF_NULL(client_to_scheduler_); + client_to_scheduler_->Init(); + client_thread_ = std::make_unique([&]() { + MS_LOG(INFO) << "The node start a tcp client!"; + client_to_scheduler_->Start(); + }); + MS_EXCEPTION_IF_NULL(client_thread_); + + auto initial_node_infos = clusterConfig.initial_registered_nodes_infos; + if (initial_node_infos.empty()) { + MS_LOG(WARNING) << "There is no registered nodes in scheduler!"; + return; + } + MS_LOG(INFO) << "The scheduler start run recovery!"; + int worker_num = clusterConfig.initial_worker_num; + int server_num = clusterConfig.initial_server_num; + + node_manager_.set_worker_num(worker_num); + node_manager_.set_server_num(server_num); + node_manager_.set_next_worker_rank_id(clusterConfig.initial_next_worker_rank_id); + node_manager_.set_next_server_rank_id(clusterConfig.initial_next_server_rank_id); + node_manager_.set_total_node_num(clusterConfig.initial_total_node_num); + + for (const auto kvs : initial_node_infos) { + auto client = std::make_shared(kvs.second.ip_, kvs.second.port_, config_.get()); + client->SetMessageCallback( + [&](const std::shared_ptr &meta, const Protos &protos, const void *data, size_t size) { + MS_LOG(INFO) << "received the response. "; + NotifyMessageArrival(meta); + }); + client->Init(); + MS_EXCEPTION_IF_NULL(client); + + auto message_meta = std::make_shared(); + MS_EXCEPTION_IF_NULL(message_meta); + message_meta->set_cmd(NodeCommand::SCHEDULER_RECOVERY); + + int rank_id = kvs.second.rank_id_; + SendMetadataMessage scheduler_recovery_message; + scheduler_recovery_message.set_worker_num(worker_num); + scheduler_recovery_message.set_server_num(server_num); + scheduler_recovery_message.set_rank_id(rank_id); + if (!SendMessageSync(client, message_meta, Protos::PROTOBUF, scheduler_recovery_message.SerializeAsString().data(), + scheduler_recovery_message.ByteSizeLong())) { + if (kvs.second.node_role_ == NodeRole::WORKER) { + is_worker_timeout_ = true; + break; + } + MS_LOG(WARNING) << "Scheduler send recovery msg to " << kvs.first << " timeout!"; + } else { + MS_LOG(INFO) << "Scheduler send recovery msg to " << kvs.first << " successful."; + } + } + MS_LOG(INFO) << "Scheduler recovery finish."; +} + void SchedulerNode::ProcessHeartbeat(const std::shared_ptr &server, const std::shared_ptr &conn, const std::shared_ptr &meta, const void *data, size_t size) { @@ -59,6 +134,7 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr &server, CHECK_RETURN_TYPE(heartbeat_message.ParseFromArray(data, SizeToInt(size))); node_manager_.UpdateHeartbeat(heartbeat_message.node_id()); + MS_LOG(DEBUG) << "The scheduler get a heartbeat from node id :" << heartbeat_message.node_id(); HeartbeatRespMessage heartbeat_resp_message; @@ -78,11 +154,6 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr &server, } void SchedulerNode::Initialize() { - config_ = std::make_unique(PSContext::instance()->config_file_path()); - MS_EXCEPTION_IF_NULL(config_); - if (!config_->Initialize()) { - MS_LOG(INFO) << "The config file is empty."; - } InitCommandHandler(); CreateTcpServer(); is_already_stopped_ = false; @@ -95,6 +166,7 @@ void SchedulerNode::Initialize() { if (node_info_.node_id_.empty()) { node_info_.node_id_ = CommUtil::GenerateUUID(); } + node_info_.rank_id_ = 0; node_info_.node_role_ = NodeRole::SCHEDULER; leader_scaler_ = std::make_unique(this); MS_EXCEPTION_IF_NULL(leader_scaler_); @@ -118,6 +190,7 @@ void SchedulerNode::CreateTcpServer() { std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host; uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port; + MS_LOG(INFO) << "scheduler ip: " << scheduler_host << ", scheduler ip: " << scheduler_port; server_ = std::make_shared(scheduler_host, scheduler_port, config_.get()); MS_EXCEPTION_IF_NULL(server_); server_->SetMessageCallback([&](const std::shared_ptr &conn, const std::shared_ptr &meta, @@ -129,6 +202,18 @@ void SchedulerNode::CreateTcpServer() { (this->*handler_ptr)(server_, conn, meta, data, size); }); + const auto client_disconn = [&](const TcpServer &, const TcpConnection &conn) { + int fd = conn.GetFd(); + if (register_connection_fd_.count(fd) <= 0) { + return; + } + MS_LOG(WARNING) << "remove client fd:" << fd << ", remove client id:" << register_connection_fd_[fd]; + register_connection_fd_.erase(fd); + MS_LOG(WARNING) << "Register node number is:" << register_connection_fd_.size() + << ", total node num is:" << node_manager_.total_node_num() + << ", scale in node size is: " << scale_in_node_ids_.size(); + }; + server_->SetServerCallback(nullptr, client_disconn, nullptr); server_->Init(); scheduler_thread_ = std::make_unique([this]() { @@ -147,9 +232,7 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr &server, MS_EXCEPTION_IF_NULL(data); RegisterMessage register_message; CHECK_RETURN_TYPE(register_message.ParseFromArray(data, SizeToInt(size))); - const std::string &node_id = register_message.node_id(); - node_manager_.UpdateHeartbeat(node_id); MS_LOG(INFO) << "The node id:" << node_id << " is registering to scheduler."; client_mutex_.lock(); @@ -164,16 +247,26 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr &server, // assign worker node and server node rank id uint32_t rank_id = node_manager_.NextRankId(register_message, meta); if (rank_id == UINT32_MAX) { - MS_LOG(WARNING) << "The rank id is wrong!"; + MS_LOG(ERROR) << "The rank id is wrong, return register rejected message!"; + RegisterRespMessage register_rejected_message; + register_rejected_message.set_node_id(node_id); + register_rejected_message.set_rank_id(rank_id); + if (!server->SendMessage(conn, meta, Protos::PROTOBUF, register_rejected_message.SerializeAsString().data(), + register_rejected_message.ByteSizeLong())) { + MS_LOG(WARNING) << "Server response message failed."; + } + return; } + node_manager_.UpdateHeartbeat(node_id); RegisterRespMessage register_resp_message; register_resp_message.set_node_id(node_id); - + register_resp_message.set_rank_id(rank_id); if (!server->SendMessage(conn, meta, Protos::PROTOBUF, register_resp_message.SerializeAsString().data(), register_resp_message.ByteSizeLong())) { MS_LOG(WARNING) << "Server response message failed."; } + SetRegisterConnectionFd(conn, node_id); if (node_manager_.IsAllNodesRegistered()) { is_ready_ = true; @@ -199,6 +292,7 @@ void SchedulerNode::ProcessRegister(const std::shared_ptr &server, node_manager_.UpdateHeartbeat(kvs.first); } node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); + PersistMetaData(); wait_start_cond_.notify_all(); } } @@ -284,6 +378,7 @@ void SchedulerNode::ProcessScaleOutDone(const std::shared_ptr &server } is_ready_ = true; node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); + PersistMetaData(); } } @@ -312,6 +407,7 @@ void SchedulerNode::ProcessScaleInDone(const std::shared_ptr &server, } is_ready_ = true; node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); + PersistMetaData(); } } @@ -440,6 +536,7 @@ void SchedulerNode::SendEvent(const std::shared_ptr &client, const ui void SchedulerNode::StartUpdateClusterStateTimer() { MS_LOG(INFO) << "[Scheduler start]: 3. The scheduler start a heartbeat timer!"; + node_manager_.setPersistCallback([&]() { PersistMetaData(); }); update_state_thread_ = std::make_unique([&]() { auto start_time = std::chrono::steady_clock::now(); while (!is_finish_.load()) { @@ -472,6 +569,7 @@ const std::shared_ptr &SchedulerNode::GetOrCreateClient(const NodeInf } std::string ip = node_info.ip_; uint16_t port = node_info.port_; + MS_LOG(DEBUG) << "ip:" << ip << ", port:" << port << ", node id:" << node_info.node_id_; auto client = std::make_shared(ip, port, config_.get()); MS_EXCEPTION_IF_NULL(client); client->SetMessageCallback( @@ -487,15 +585,6 @@ const std::shared_ptr &SchedulerNode::GetOrCreateClient(const NodeInf NotifyMessageArrival(meta); }); client->Init(); - if (is_client_started_ == false) { - is_client_started_ = true; - client_thread_ = std::make_unique([&]() { - MS_LOG(INFO) << "The node start a tcp client!"; - client->Start(); - }); - MS_EXCEPTION_IF_NULL(client_thread_); - } - connected_nodes_[node_info.node_id_] = client; return connected_nodes_[node_info.node_id_]; } @@ -524,7 +613,7 @@ bool SchedulerNode::Stop() { is_ready_ = true; } if (PSContext::instance()->scheduler_manage_port() != 0) { - MS_LOG(WARNING) << "Stop the scheduler http service, the ip:" << PSContext::instance()->scheduler_ip() + MS_LOG(WARNING) << "Stop the restful scheduler http service, the ip is 127.0.0.1 " << ", the port:" << PSContext::instance()->scheduler_manage_port(); StopRestfulServer(); } @@ -532,7 +621,7 @@ bool SchedulerNode::Stop() { } bool SchedulerNode::Finish(const uint32_t &) { - MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to finish scheduler node!"; + MS_LOG(INFO) << "[Scheduler finish]: 1. Begin to listen finish scheduler node!"; std::unique_lock lock(wait_finish_mutex_); wait_finish_cond_.wait(lock, [this] { if (this->is_finish_.load()) { @@ -543,6 +632,45 @@ bool SchedulerNode::Finish(const uint32_t &) { return true; } +void SchedulerNode::ProcessScaleoutRollback(const std::shared_ptr &resp) { + MS_EXCEPTION_IF_NULL(resp); + RequestProcessResult status(RequestProcessResultCode::kSuccess); + if (node_manager_.GetClusterState() != ClusterState::CLUSTER_SCALE_OUT) { + std::string message = "The cluster state is not CLUSTER_SCALE_OUT, does not need to rollback."; + ERROR_STATUS(status, RequestProcessResultCode::kSystemError, message); + resp->ErrorResponse(HTTP_BADREQUEST, status); + return; + } + // set the last worker num and last server num + ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); + node_manager_.set_worker_num(clusterConfig.initial_worker_num); + node_manager_.set_server_num(clusterConfig.initial_server_num); + node_manager_.set_total_node_num(clusterConfig.initial_total_node_num); + + MS_LOG(INFO) << "After scale out rollback, the last worker num:" << clusterConfig.initial_worker_num + << ", the last server num:" << clusterConfig.initial_server_num; + + auto node_infos = node_manager_.nodes_info(); + node_manager_.ResetMetadata(); + for (const auto &kvs : node_infos) { + auto client = GetOrCreateClient(kvs.second); + MS_EXCEPTION_IF_NULL(client); + MS_EXCEPTION_IF_NULL(leader_scaler_); + leader_scaler_->ScaleOutAsync(client, node_manager_); + } + + MS_LOG(INFO) << "Scheduler send scale out rollback successful."; + node_manager_.UpdateClusterState(ClusterState::CLUSTER_SCALE_OUT); + nlohmann::json js; + js["message"] = "Cluster scale out rollback success."; + js["code"] = kSuccessCode; + resp->AddRespString(js.dump()); + resp->AddRespHeadParam("Content-Type", "application/json"); + + resp->SetRespCode(HTTP_OK); + resp->SendResponse(); +} + void SchedulerNode::ProcessScaleOut(const std::shared_ptr &resp) { MS_EXCEPTION_IF_NULL(resp); RequestProcessResult status(RequestProcessResultCode::kSuccess); @@ -595,6 +723,7 @@ void SchedulerNode::ProcessScaleOut(const std::shared_ptr &r nlohmann::json js; js["message"] = "Cluster begin to scale out."; + js["code"] = kSuccessCode; resp->AddRespString(js.dump()); resp->AddRespHeadParam("Content-Type", "application/json"); @@ -637,7 +766,7 @@ void SchedulerNode::ProcessScaleIn(const std::shared_ptr &re MS_LOG(WARNING) << "The scale in node ids:" << scale_in_node_ids_; - mindspore::HashMap scale_in_nodes; + std::unordered_map scale_in_nodes; int32_t scale_worker_num = 0; int32_t scale_server_num = 0; @@ -676,6 +805,7 @@ void SchedulerNode::ProcessScaleIn(const std::shared_ptr &re nlohmann::json js; js["message"] = "Cluster begin to scale in."; + js["code"] = kSuccessCode; resp->AddRespString(js.dump()); resp->AddRespHeadParam("Content-Type", "application/json"); @@ -705,15 +835,24 @@ void SchedulerNode::ProcessGetNodesInfo(const std::shared_ptr res; - res["node_id"] = kvs.second.node_id_; - res["rank_id"] = std::to_string(kvs.second.rank_id_); + res["nodeId"] = kvs.second.node_id_; + res["rankId"] = std::to_string(kvs.second.rank_id_); res["role"] = CommUtil::NodeRoleToString(kvs.second.node_role_); - js["node_ids"].push_back(std::move(res)); + res["alive"] = kvs.second.is_alive ? "true" : "false"; + js["nodeIds"].push_back(res); } + std::unordered_map scheduler_info; + scheduler_info["nodeId"] = node_info_.node_id_; + scheduler_info["rankId"] = std::to_string(node_info_.rank_id_); + scheduler_info["role"] = CommUtil::NodeRoleToString(node_info_.node_role_); + scheduler_info["alive"] = "true"; + js["nodeIds"].push_back(scheduler_info); + resp->AddRespString(js.dump()); resp->AddRespHeadParam("Content-Type", "application/json"); @@ -734,6 +873,7 @@ void SchedulerNode::ProcessGetClusterState(const std::shared_ptrAddRespString(js.dump()); resp->AddRespHeadParam("Content-Type", "application/json"); @@ -765,12 +905,12 @@ void SchedulerNode::ProcessNewInstance(const std::shared_ptr uint64_t request_id = AddMessageTrack(node_manager_.server_num()); - mindspore::HashMap outputs; + std::unordered_map outputs; set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); outputs = receive_messages_[request_id]; - receive_messages_.erase(request_id); + (void)receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); }); @@ -793,7 +933,8 @@ void SchedulerNode::ProcessNewInstance(const std::shared_ptr node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); nlohmann::json js; - js["message"] = "Start update flPlan successful."; + js["message"] = "Start new instance successful."; + js["code"] = kSuccessCode; for (const auto &output : outputs) { std::string data = std::string(reinterpret_cast(output.second->data()), output.second->size()); js["result"][output.first] = data; @@ -819,12 +960,12 @@ void SchedulerNode::ProcessQueryInstance(const std::shared_ptr outputs; + std::unordered_map outputs; set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); outputs = receive_messages_[request_id]; - receive_messages_.erase(request_id); + (void)receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); }); @@ -845,10 +986,13 @@ void SchedulerNode::ProcessQueryInstance(const std::shared_ptr(output.second->data()), output.second->size()); - js["result"][output.first] = data; + nlohmann::json dataJson = nlohmann::json::parse(data); + js["result"] = dataJson; + break; } resp->AddRespString(js.dump()); @@ -873,12 +1017,12 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr & uint64_t request_id = AddMessageTrack(node_manager_.server_num()); - mindspore::HashMap outputs; + std::unordered_map outputs; set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); outputs = receive_messages_[request_id]; - receive_messages_.erase(request_id); + (void)receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); }); @@ -902,6 +1046,7 @@ void SchedulerNode::ProcessEnableFLS(const std::shared_ptr & node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); nlohmann::json js; js["message"] = "start enabling FL-Server successful."; + js["code"] = kSuccessCode; for (const auto &output : outputs) { std::string data = std::string(reinterpret_cast(output.second->data()), output.second->size()); js["result"][output.first] = data; @@ -929,12 +1074,12 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr uint64_t request_id = AddMessageTrack(node_manager_.server_num()); - mindspore::HashMap outputs; + std::unordered_map outputs; set_message_callback(request_id, [&]() { receive_messages_mutex_.lock(); outputs = receive_messages_[request_id]; - receive_messages_.erase(request_id); + (void)receive_messages_.erase(request_id); receive_messages_mutex_.unlock(); }); @@ -958,6 +1103,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr node_manager_.UpdateClusterState(ClusterState::CLUSTER_READY); nlohmann::json js; js["message"] = "start disabling FL-Server successful."; + js["code"] = kSuccessCode; for (const auto &output : outputs) { std::string data = std::string(reinterpret_cast(output.second->data()), output.second->size()); js["result"][output.first] = data; @@ -972,7 +1118,7 @@ void SchedulerNode::ProcessDisableFLS(const std::shared_ptr RequestProcessResult SchedulerNode::CheckIfClusterReady() { RequestProcessResult result(RequestProcessResultCode::kSuccess); - if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY) { + if (node_manager_.GetClusterState() != ClusterState::CLUSTER_READY || CheckIfNodeDisconnected()) { std::string message = "The cluster is not ready."; ERROR_STATUS(result, RequestProcessResultCode::kSystemError, message); return result; @@ -1039,19 +1185,23 @@ void SchedulerNode::StartRestfulServer(const std::string &address, std::uint16_t OnRequestReceive new_instance = std::bind(&SchedulerNode::ProcessNewInstance, this, std::placeholders::_1); callbacks_["/newInstance"] = new_instance; - http_server_->RegisterRoute("/newInstance", &callbacks_["/newInstance"]); + (void)http_server_->RegisterRoute("/newInstance", &callbacks_["/newInstance"]); OnRequestReceive query_instance = std::bind(&SchedulerNode::ProcessQueryInstance, this, std::placeholders::_1); callbacks_["/queryInstance"] = query_instance; - http_server_->RegisterRoute("/queryInstance", &callbacks_["/queryInstance"]); + (void)http_server_->RegisterRoute("/queryInstance", &callbacks_["/queryInstance"]); OnRequestReceive enable_fls = std::bind(&SchedulerNode::ProcessEnableFLS, this, std::placeholders::_1); callbacks_["/enableFLS"] = enable_fls; - http_server_->RegisterRoute("/enableFLS", &callbacks_["/enableFLS"]); + (void)http_server_->RegisterRoute("/enableFLS", &callbacks_["/enableFLS"]); OnRequestReceive disable_fls = std::bind(&SchedulerNode::ProcessDisableFLS, this, std::placeholders::_1); callbacks_["/disableFLS"] = disable_fls; - http_server_->RegisterRoute("/disableFLS", &callbacks_["/disableFLS"]); + (void)http_server_->RegisterRoute("/disableFLS", &callbacks_["/disableFLS"]); + + OnRequestReceive scale_out_rollback = std::bind(&SchedulerNode::ProcessScaleoutRollback, this, std::placeholders::_1); + callbacks_["/scaleoutRollback"] = scale_out_rollback; + (void)http_server_->RegisterRoute("/scaleoutRollback", &callbacks_["/scaleoutRollback"]); if (!http_server_->InitServer()) { MS_LOG(EXCEPTION) << "The scheduler init http server failed."; @@ -1075,6 +1225,78 @@ void SchedulerNode::StopRestfulServer() { restful_thread_->join(); } } + +void SchedulerNode::InitNodeMetaData() { + ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); + clusterConfig.scheduler_host = PSContext::instance()->scheduler_host(); + clusterConfig.scheduler_port = PSContext::instance()->scheduler_port(); + clusterConfig.initial_worker_num = PSContext::instance()->initial_worker_num(); + clusterConfig.initial_server_num = PSContext::instance()->initial_server_num(); + MS_LOG(INFO) << "The cluster worker num:" << clusterConfig.initial_worker_num + << ", the server num:" << clusterConfig.initial_server_num + << ", the scheduler ip:" << clusterConfig.scheduler_host + << ", the scheduler port:" << clusterConfig.scheduler_port; +} + +bool SchedulerNode::RecoverScheduler() { + MS_EXCEPTION_IF_NULL(config_); + if (config_->Exists(kKeyRecovery)) { + MS_LOG(INFO) << "The scheduler node is support recovery."; + scheduler_recovery_ = std::make_unique(); + MS_EXCEPTION_IF_NULL(scheduler_recovery_); + (void)scheduler_recovery_->Initialize(config_->Get(kKeyRecovery, "")); + (void)scheduler_recovery_->InitializeNodes(config_->Get(kKeyRecovery, "")); + + return scheduler_recovery_->Recover(); + } + return false; +} + +void SchedulerNode::PersistMetaData() { + if (scheduler_recovery_ == nullptr) { + MS_LOG(WARNING) << "scheduler recovery is null, so don't persist meta data"; + return; + } + if (config_->Exists(kKeyRecovery)) { + ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); + clusterConfig.initial_worker_num = node_manager_.worker_num(); + clusterConfig.initial_server_num = node_manager_.server_num(); + clusterConfig.initial_total_node_num = node_manager_.total_node_num(); + clusterConfig.initial_next_worker_rank_id = node_manager_.next_worker_rank_id(); + clusterConfig.initial_next_server_rank_id = node_manager_.next_server_rank_id(); + clusterConfig.initial_registered_nodes_infos.clear(); + clusterConfig.initial_registered_nodes_infos = node_manager_.registered_nodes_info(); + + scheduler_recovery_->Persist(clusterConfig); + scheduler_recovery_->PersistNodesInfo(clusterConfig); + } +} + +bool SchedulerNode::CheckIfNodeDisconnected() const { + return UintToInt(register_connection_fd_.size()) != node_manager_.total_node_num(); +} + +void SchedulerNode::BroadcastTimeoutEvent() { + core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); + auto initial_node_infos = clusterConfig.initial_registered_nodes_infos; + const uint32_t event = static_cast(ps::UserDefineEvent::kNodeTimeout); + MS_LOG(INFO) << "Broad timeout event:" << event; + for (const auto kvs : initial_node_infos) { + auto client = GetOrCreateClient(kvs.second); + SendEvent(client, event); + } + MS_LOG(INFO) << "Broad timeout event finish."; +} + +void SchedulerNode::SetRegisterConnectionFd(const std::shared_ptr &conn, const std::string &node_id) { + int fd = conn->GetFd(); + if (register_connection_fd_.count(fd) > 0) { + MS_LOG(WARNING) << "This server has contained the fd:" << fd; + return; + } + MS_LOG(INFO) << "register client fd:" << fd << ", register client id:" << node_id; + register_connection_fd_[fd] = node_id; +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/scheduler_node.h b/mindspore/ccsrc/ps/core/scheduler_node.h index b70f40fd886..daa12beef55 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.h +++ b/mindspore/ccsrc/ps/core/scheduler_node.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,8 +25,8 @@ #include #include #include +#include -#include "utils/hash_map.h" #include "ps/core/cluster_config.h" #include "ps/ps_context.h" #include "ps/core/communicator/tcp_client.h" @@ -56,7 +56,8 @@ class SchedulerNode : public Node { client_thread_(nullptr), is_client_started_(false), leader_scaler_(nullptr), - scheduler_recovery_(nullptr) {} + scheduler_recovery_(nullptr), + is_worker_timeout_(false) {} ~SchedulerNode() override; typedef void (SchedulerNode::*ResponseHandler)(const std::shared_ptr &server, @@ -135,6 +136,10 @@ class SchedulerNode : public Node { // Handle the disable FLS http request Synchronously. void ProcessDisableFLS(const std::shared_ptr &resp); + // Handle the scale out rollback http request, then delegate to the leader scaler to + // process scale out rollback asynchronously. + void ProcessScaleoutRollback(const std::shared_ptr &resp); + // check whether the cluster is in the ready state. RequestProcessResult CheckIfClusterReady(); @@ -142,12 +147,27 @@ class SchedulerNode : public Node { RequestProcessResult CheckIfNodeIdLegal(const std::vector &node_ids); void StartRestfulServer(const std::string &address, std::uint16_t port, size_t thread_num = 10); + void StopRestfulServer(); + void InitNodeMetaData(); + + bool RecoverScheduler(); + + void PersistMetaData(); + + bool CheckIfNodeDisconnected() const; + + void RunRecovery(); + + void BroadcastTimeoutEvent(); + + void SetRegisterConnectionFd(const std::shared_ptr &conn, const std::string &node_id); + std::shared_ptr server_; std::unique_ptr scheduler_thread_; std::unique_ptr update_state_thread_; - mindspore::HashMap handlers_; + std::unordered_map handlers_; NodeManager node_manager_; @@ -155,14 +175,15 @@ class SchedulerNode : public Node { std::unique_ptr restful_thread_; std::shared_ptr http_server_; - mindspore::HashMap> connected_nodes_; + std::unordered_map> connected_nodes_; + std::shared_ptr client_to_scheduler_; std::unique_ptr client_thread_; std::atomic is_client_started_; std::unique_ptr leader_scaler_; - mindspore::HashMap callbacks_; + std::unordered_map callbacks_; // Used to persist and obtain metadata information for scheduler. std::unique_ptr scheduler_recovery_; @@ -171,6 +192,10 @@ class SchedulerNode : public Node { std::vector scale_in_node_ids_; std::unique_ptr instance_manager_; + + std::atomic is_worker_timeout_; + // This is a map of register connection fd to client node id + std::unordered_map register_connection_fd_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/scheduler_recovery.cc b/mindspore/ccsrc/ps/core/scheduler_recovery.cc index 47a1a56e184..e0422f1d5ee 100644 --- a/mindspore/ccsrc/ps/core/scheduler_recovery.cc +++ b/mindspore/ccsrc/ps/core/scheduler_recovery.cc @@ -19,17 +19,115 @@ namespace mindspore { namespace ps { namespace core { -void SchedulerRecovery::Persist(const std::string &key, const std::string &value) { - MS_EXCEPTION_IF_NULL(recovery_storage_); - recovery_storage_->Put(key, value); -} - std::string SchedulerRecovery::GetMetadata(const std::string &key) { MS_EXCEPTION_IF_NULL(recovery_storage_); return recovery_storage_->Get(key, ""); } -bool SchedulerRecovery::Recover() { return true; } +bool SchedulerRecovery::Recover() { + if (recovery_storage_ == nullptr) { + return false; + } + core::ClusterConfig &clusterConfig = PSContext::instance()->cluster_config(); + + // 1. recover worker num + if (recovery_storage_->Exists(kRecoveryWorkerNum)) { + clusterConfig.initial_worker_num = + std::strtol(recovery_storage_->Get(kRecoveryWorkerNum, "").c_str(), nullptr, kBase); + } else { + clusterConfig.initial_worker_num = PSContext::instance()->initial_worker_num(); + } + + // 2. recover server num + if (recovery_storage_->Exists(kRecoveryServerNum)) { + clusterConfig.initial_server_num = + std::strtol(recovery_storage_->Get(kRecoveryServerNum, "").c_str(), nullptr, kBase); + } else { + clusterConfig.initial_server_num = PSContext::instance()->initial_server_num(); + } + + // 3. recover scheduler ip + if (recovery_storage_->Exists(kRecoverySchedulerIp)) { + clusterConfig.scheduler_host = recovery_storage_->GetString(kRecoverySchedulerIp, ""); + } else { + clusterConfig.scheduler_host = PSContext::instance()->scheduler_host(); + } + + // 4. recover scheduler port + if (recovery_storage_->Exists(kRecoverySchedulerPort)) { + clusterConfig.scheduler_port = + std::strtol(recovery_storage_->Get(kRecoverySchedulerPort, "").c_str(), nullptr, kBase); + } else { + clusterConfig.scheduler_port = PSContext::instance()->scheduler_port(); + } + + MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num + << ", the server num:" << clusterConfig.initial_server_num + << ", the scheduler ip:" << clusterConfig.scheduler_host + << ", the scheduler port:" << clusterConfig.scheduler_port; + + if (scheduler_recovery_storage_ == nullptr) { + MS_LOG(WARNING) << "scheduler recovery storage is null. return false"; + return false; + } + // 5. recover total node num + if (scheduler_recovery_storage_->Exists(kRecoveryTotalNodeNum)) { + clusterConfig.initial_total_node_num = + std::strtol(scheduler_recovery_storage_->Get(kRecoveryTotalNodeNum, "").c_str(), nullptr, kBase); + } + + // 6. recover next worker rank id + if (scheduler_recovery_storage_->Exists(kRecoveryNextWorkerRankId)) { + clusterConfig.initial_next_worker_rank_id = + std::strtol(scheduler_recovery_storage_->Get(kRecoveryNextWorkerRankId, "").c_str(), nullptr, kBase); + } + + // 7. recover next server rank id + if (scheduler_recovery_storage_->Exists(kRecoveryNextServerRankId)) { + clusterConfig.initial_next_server_rank_id = + std::strtol(scheduler_recovery_storage_->Get(kRecoveryNextServerRankId, "").c_str(), nullptr, kBase); + } + + // 8. recover register nodes info + if (scheduler_recovery_storage_->Exists(kRecoveryRegisteredNodesInfos)) { + auto node_ids = scheduler_recovery_storage_->GetVector(kRecoveryRegisteredNodesInfos); + std::unordered_map nodes_infos; + for (auto elem : node_ids) { + std::string port = elem.at("port"); + std::string rank_id = elem.at("rank_id"); + + NodeInfo node_info; + node_info.ip_ = elem.at("ip"); + node_info.port_ = std::strtol(port.c_str(), nullptr, kBase); + node_info.node_id_ = elem.at("node_id"); + node_info.rank_id_ = std::strtol(rank_id.c_str(), nullptr, kBase); + node_info.is_alive = CommUtil::StringToBool(elem.at("alive")); + node_info.node_role_ = CommUtil::StringToNodeRole(elem.at("role")); + + nodes_infos[node_info.node_id_] = node_info; + } + clusterConfig.initial_registered_nodes_infos = nodes_infos; + } + + MS_LOG(INFO) << "The worker num:" << clusterConfig.initial_worker_num + << ", the server num:" << clusterConfig.initial_server_num + << ", the scheduler ip:" << clusterConfig.scheduler_host + << ", the scheduler port:" << clusterConfig.scheduler_port + << ", the initial total node num:" << clusterConfig.initial_total_node_num + << ", the initial next worker rank id:" << clusterConfig.initial_next_worker_rank_id + << ", the initial next server rank id:" << clusterConfig.initial_next_server_rank_id; + + if (!clusterConfig.initial_registered_nodes_infos.empty()) { + for (const auto kvs : clusterConfig.initial_registered_nodes_infos) { + MS_LOG(INFO) << "The ip:" << kvs.second.ip_ << ", the port:" << kvs.second.port_ + << ", the node_id:" << kvs.second.node_id_ + << ", the node_role:" << CommUtil::NodeRoleToString(kvs.second.node_role_) + << ", the rank_id_:" << kvs.second.rank_id_ + << ", the is_alive:" << CommUtil::BoolToString(kvs.second.is_alive); + } + } + return true; +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/scheduler_recovery.h b/mindspore/ccsrc/ps/core/scheduler_recovery.h index 08488ffa094..9fcacda3fdd 100644 --- a/mindspore/ccsrc/ps/core/scheduler_recovery.h +++ b/mindspore/ccsrc/ps/core/scheduler_recovery.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "ps/constants.h" #include "utils/log_adapter.h" @@ -29,6 +30,7 @@ #include "ps/ps_context.h" #include "ps/core/recovery_base.h" #include "ps/core/scheduler_node.h" +#include "ps/core/node_info.h" namespace mindspore { namespace ps { @@ -39,9 +41,6 @@ class SchedulerRecovery : public RecoveryBase { SchedulerRecovery() = default; ~SchedulerRecovery() override = default; - // Persist metadata to storage. - void Persist(const std::string &key, const std::string &value); - bool Recover() override; // Get metadata from storage. diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 9e388ee1f8b..357f474ee90 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -24,6 +24,9 @@ bool ServerNode::Start(const uint32_t &timeout) { MS_LOG(INFO) << "[Server start]: 1. Begin to start server node!"; Initialize(); Register(client_to_scheduler_); + if (node_info_.rank_id_ == UINT32_MAX) { + MS_LOG(EXCEPTION) << "Register to scheduler failed, so finish the node."; + } MS_LOG(INFO) << "[Server start]: 4. The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) << " the node id:" << node_info_.node_id_ << " successfully registered to the scheduler!"; @@ -34,7 +37,6 @@ bool ServerNode::Start(const uint32_t &timeout) { MS_LOG(ERROR) << "Start server node timeout!"; return false; } - MsException::Instance().CheckException(); MS_LOG(INFO) << "[Server start]: 6. Successfully start server node!"; return true; @@ -61,6 +63,7 @@ void ServerNode::Initialize() { if (!InitClientToScheduler()) { MS_LOG(EXCEPTION) << "Server node connect to scheduler timedout!"; } + InitClientToServer(); is_already_stopped_ = false; MS_LOG(INFO) << "[Server start]: 3. Server node crete tcp client to scheduler successful!"; } diff --git a/mindspore/ccsrc/ps/core/server_node.h b/mindspore/ccsrc/ps/core/server_node.h index 73cc8e35286..81521c180ab 100644 --- a/mindspore/ccsrc/ps/core/server_node.h +++ b/mindspore/ccsrc/ps/core/server_node.h @@ -1,5 +1,5 @@ /** - * Copyright 2020-2021 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,8 @@ #include #include #include +#include -#include "utils/hash_map.h" #include "ps/core/cluster_metadata.h" #include "ps/core/cluster_config.h" #include "ps/ps_context.h" diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index 61ff54087d3..1617be36b6c 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -24,6 +24,9 @@ bool WorkerNode::Start(const uint32_t &timeout) { MS_LOG(INFO) << "[Worker start]: 1. Begin to start worker node!"; Initialize(); Register(client_to_scheduler_); + if (node_info_.rank_id_ == UINT32_MAX) { + MS_LOG(EXCEPTION) << "Register to scheduler failed, so finish the node."; + } MS_LOG(INFO) << "[Worker start]: 4. The node role:" << CommUtil::NodeRoleToString(node_info_.node_role_) << " the node id:" << node_info_.node_id_ << " successfully registered to the scheduler!"; @@ -34,7 +37,6 @@ bool WorkerNode::Start(const uint32_t &timeout) { MS_LOG(ERROR) << "Start Worker node timeout!"; return false; } - MsException::Instance().CheckException(); MS_LOG(INFO) << "[Worker start]: 6. Successfully start worker node!"; return true; @@ -61,6 +63,7 @@ void WorkerNode::Initialize() { if (!InitClientToScheduler()) { MS_LOG(EXCEPTION) << "Worker node connect to scheduler timeout!"; } + InitClientToServer(); is_already_stopped_ = false; MS_LOG(INFO) << "[Worker start]: 3. Worker node crete tcp client to scheduler successful!"; } diff --git a/mindspore/ccsrc/ps/core/worker_node.h b/mindspore/ccsrc/ps/core/worker_node.h index c003e840c8f..0789a380670 100644 --- a/mindspore/ccsrc/ps/core/worker_node.h +++ b/mindspore/ccsrc/ps/core/worker_node.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_PS_CORE_WORKER_NODE_H_ -#define MINDSPORE_CCSRC_PS_CORE_WORKER_NODE_H_ +#ifndef MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ +#define MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ #include #include @@ -48,4 +48,4 @@ class WorkerNode : public AbstractNode { } // namespace ps } // namespace mindspore -#endif // MINDSPORE_CCSRC_PS_CORE_WORKER_NODE_H_ +#endif // MINDSPORE_CCSRC_PS_CORE_CLIENT_NODE_H_ diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 1b92625b6f2..78e6d228d5c 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -28,7 +28,7 @@ namespace ps { std::shared_ptr PSContext::instance() { static std::shared_ptr ps_instance = nullptr; if (ps_instance == nullptr) { - ps_instance.reset(new (std::nothrow) PSContext()); + ps_instance.reset(new PSContext()); } return ps_instance; } @@ -207,6 +207,7 @@ void PSContext::set_encrypt_type(const std::string &encrypt_type) { } encrypt_type_ = encrypt_type; } + const std::string &PSContext::encrypt_type() const { return encrypt_type_; } void PSContext::set_dp_eps(float dp_eps) { @@ -284,7 +285,7 @@ void PSContext::GenerateResetterRound() { bool is_parameter_server_mode = false; bool is_federated_learning_mode = false; bool is_mixed_training_mode = false; - bool use_pairwise_encrypt = (encrypt_type_ == kPWEncryptType); + bool is_pairwise_encrypt = (encrypt_type_ == kPWEncryptType); if (server_mode_ == kServerModePS) { is_parameter_server_mode = true; @@ -297,9 +298,11 @@ void PSContext::GenerateResetterRound() { << " or " << kServerModeHybrid; return; } - + const int training_mode_offset = 2; + const int pairwise_encrypt_offset = 3; binary_server_context = ((unsigned int)is_parameter_server_mode) | ((unsigned int)is_federated_learning_mode << 1) | - ((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)use_pairwise_encrypt << 3); + ((unsigned int)is_mixed_training_mode << training_mode_offset) | + ((unsigned int)is_pairwise_encrypt << pairwise_encrypt_offset); if (kServerContextToResetRoundMap.count(binary_server_context) == 0) { resetter_round_ = ResetterRound::kNoNeedToReset; } else { @@ -404,9 +407,9 @@ void PSContext::set_worker_step_num_per_iteration(uint64_t worker_step_num_per_i uint64_t PSContext::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; } -bool PSContext::enable_ssl() const { return enable_ssl_; } +void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; } -void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; } +bool PSContext::secure_aggregation() const { return secure_aggregation_; } core::ClusterConfig &PSContext::cluster_config() { if (cluster_config_ == nullptr) { @@ -416,8 +419,30 @@ core::ClusterConfig &PSContext::cluster_config() { return *cluster_config_; } -void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; } +void PSContext::set_root_first_ca_path(const std::string &root_first_ca_path) { + root_first_ca_path_ = root_first_ca_path; +} +void PSContext::set_root_second_ca_path(const std::string &root_second_ca_path) { + root_second_ca_path_ = root_second_ca_path; +} +std::string PSContext::root_first_ca_path() const { return root_first_ca_path_; } +std::string PSContext::root_second_ca_path() const { return root_second_ca_path_; } + +void PSContext::set_pki_verify(bool pki_verify) { pki_verify_ = pki_verify; } +bool PSContext::pki_verify() const { return pki_verify_; } + +void PSContext::set_replay_attack_time_diff(uint64_t replay_attack_time_diff) { + replay_attack_time_diff_ = replay_attack_time_diff; +} + +uint64_t PSContext::replay_attack_time_diff() const { return replay_attack_time_diff_; } + +std::string PSContext::equip_crl_path() const { return equip_crl_path_; } + +void PSContext::set_equip_crl_path(const std::string &equip_crl_path) { equip_crl_path_ = equip_crl_path; } + +void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; } uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; } void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; } @@ -428,10 +453,18 @@ void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; } const std::string &PSContext::node_id() const { return node_id_; } +bool PSContext::enable_ssl() const { return enable_ssl_; } + +void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; } + std::string PSContext::client_password() const { return client_password_; } void PSContext::set_client_password(const std::string &password) { client_password_ = password; } std::string PSContext::server_password() const { return server_password_; } void PSContext::set_server_password(const std::string &password) { server_password_ = password; } + +std::string PSContext::http_url_prefix() const { return http_url_prefix_; } + +void PSContext::set_http_url_prefix(const std::string &http_url_prefix) { http_url_prefix_ = http_url_prefix; } } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index 60b2823cc6f..890f570ac9b 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -81,19 +81,15 @@ class PSContext { void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const; void set_cache_enable(bool cache_enable) const; void set_rank_id(uint32_t rank_id) const; - bool enable_ssl() const; - void set_enable_ssl(bool enabled); - - std::string client_password() const; - void set_client_password(const std::string &password); - std::string server_password() const; - void set_server_password(const std::string &password); // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set // by ps_context. void set_server_mode(const std::string &server_mode); const std::string &server_mode() const; + void set_encrypt_type(const std::string &encrypt_type); + const std::string &encrypt_type() const; + void set_ms_role(const std::string &role); void set_worker_num(uint32_t worker_num); @@ -163,13 +159,9 @@ class PSContext { void set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration); uint64_t worker_step_num_per_iteration() const; - core::ClusterConfig &cluster_config(); - - void set_scheduler_manage_port(uint16_t sched_port); - uint16_t scheduler_manage_port() const; - - void set_config_file_path(const std::string &path); - std::string config_file_path() const; + // Set true if using secure aggregation for federated learning. + void set_secure_aggregation(bool secure_aggregation); + bool secure_aggregation() const; void set_dp_eps(float dp_eps); float dp_eps() const; @@ -180,19 +172,50 @@ class PSContext { void set_dp_norm_clip(float dp_norm_clip); float dp_norm_clip() const; - void set_encrypt_type(const std::string &encrypt_type); - const std::string &encrypt_type() const; + core::ClusterConfig &cluster_config(); + + void set_root_first_ca_path(const std::string &root_first_ca_path); + void set_root_second_ca_path(const std::string &root_second_ca_path); + + std::string root_first_ca_path() const; + std::string root_second_ca_path() const; + + void set_pki_verify(bool pki_verify); + bool pki_verify() const; + + void set_equip_crl_path(const std::string &equip_crl_path); + std::string equip_crl_path() const; + + void set_replay_attack_time_diff(uint64_t replay_attack_time_diff); + uint64_t replay_attack_time_diff() const; + + void set_scheduler_manage_port(uint16_t sched_port); + uint16_t scheduler_manage_port() const; + + void set_config_file_path(const std::string &path); + std::string config_file_path() const; void set_node_id(const std::string &node_id); const std::string &node_id() const; + bool enable_ssl() const; + void set_enable_ssl(bool enabled); + + std::string client_password() const; + void set_client_password(const std::string &password); + + std::string server_password() const; + void set_server_password(const std::string &password); + + std::string http_url_prefix() const; + void set_http_url_prefix(const std::string &http_url_prefix); + private: PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), - enable_ssl_(false), rank_id_(0), worker_num_(0), server_num_(0), @@ -218,20 +241,26 @@ class PSContext { worker_step_num_per_iteration_(65), secure_aggregation_(false), cluster_config_(nullptr), + root_first_ca_path_(""), + root_second_ca_path_(""), + pki_verify_(false), + equip_crl_path_(""), + replay_attack_time_diff_(60000), scheduler_manage_port_(11202), config_file_path_(""), + node_id_(""), dp_eps_(50), dp_delta_(0.01), dp_norm_clip_(1.0), encrypt_type_(kNotEncryptType), - node_id_(""), + enable_ssl_(false), client_password_(""), - server_password_("") {} + server_password_(""), + http_url_prefix_("") {} bool ps_enabled_; bool is_worker_; bool is_pserver_; bool is_sched_; - bool enable_ssl_; uint32_t rank_id_; uint32_t worker_num_; uint32_t server_num_; @@ -298,12 +327,30 @@ class PSContext { // The cluster config read through environment variables, the value does not change. std::unique_ptr cluster_config_; + // The first generation CBG root certificate + std::string root_first_ca_path_; + + // The second generation CBG root certificate + std::string root_second_ca_path_; + + // if open pki verify + bool pki_verify_; + + // The second generation CBG root CRL + std::string equip_crl_path_; + + // The replay attack time diff + uint64_t replay_attack_time_diff_; + // The port used by scheduler to receive http requests for scale out or scale in. uint16_t scheduler_manage_port_; // The path of the configuration file, used to configure the certification path and persistent storage type, etc. std::string config_file_path_; + // Unique id of the node + std::string node_id_; + // Epsilon budget of differential privacy mechanism. Used in federated learning for now. float dp_eps_; @@ -316,13 +363,14 @@ class PSContext { // Secure mechanism for federated learning. Used in federated learning for now. std::string encrypt_type_; - // Unique id of the node - std::string node_id_; - + // Whether to enable ssl for network communication. + bool enable_ssl_; // Password used to decode p12 file. std::string client_password_; // Password used to decode p12 file. std::string server_password_; + // http url prefix for http communication + std::string http_url_prefix_; }; } // namespace ps } // namespace mindspore diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py index 291897b001a..c814dca7bbb 100644 --- a/mindspore/parallel/_ps_context.py +++ b/mindspore/parallel/_ps_context.py @@ -20,17 +20,6 @@ from mindspore._c_expression import PSContext _ps_context = None -_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port", - "start_fl_job_threshold", "start_fl_job_time_window", "update_model_time_window", - "fl_iteration_num", "client_epoch_num", "client_batch_size", "scheduler_manage_port", - "cipher_time_window", "reconstruct_secrets_threshold"] - -_check_non_negative_int_keys = ["worker_num"] - -_check_positive_float_keys = ["update_model_ratio", "client_learning_rate"] - -_check_port_keys = ["scheduler_port", "fl_server_port", "scheduler_manage_port"] - def ps_context(): """ @@ -68,6 +57,11 @@ _set_ps_context_func_map = { "client_batch_size": ps_context().set_client_batch_size, "client_learning_rate": ps_context().set_client_learning_rate, "worker_step_num_per_iteration": ps_context().set_worker_step_num_per_iteration, + "root_first_ca_path": ps_context().set_root_first_ca_path, + "root_second_ca_path": ps_context().set_root_second_ca_path, + "pki_verify": ps_context().set_pki_verify, + "equip_crl_path": ps_context().set_equip_crl_path, + "replay_attack_time_diff": ps_context().set_replay_attack_time_diff, "enable_ssl": ps_context().set_enable_ssl, "client_password": ps_context().set_client_password, "server_password": ps_context().set_server_password, @@ -76,7 +70,8 @@ _set_ps_context_func_map = { "dp_eps": ps_context().set_dp_eps, "dp_delta": ps_context().set_dp_delta, "dp_norm_clip": ps_context().set_dp_norm_clip, - "encrypt_type": ps_context().set_encrypt_type + "encrypt_type": ps_context().set_encrypt_type, + "http_url_prefix": ps_context().set_http_url_prefix } _get_ps_context_func_map = { @@ -95,7 +90,7 @@ _get_ps_context_func_map = { "update_model_ratio": ps_context().update_model_ratio, "update_model_time_window": ps_context().update_model_time_window, "share_secrets_ratio": ps_context().share_secrets_ratio, - "cipher_time_window": ps_context().set_cipher_time_window, + "cipher_time_window": ps_context().cipher_time_window, "reconstruct_secrets_threshold": ps_context().reconstruct_secrets_threshold, "fl_name": ps_context().fl_name, "fl_iteration_num": ps_context().fl_iteration_num, @@ -103,13 +98,33 @@ _get_ps_context_func_map = { "client_batch_size": ps_context().client_batch_size, "client_learning_rate": ps_context().client_learning_rate, "worker_step_num_per_iteration": ps_context().worker_step_num_per_iteration, + "dp_eps": ps_context().dp_eps, + "dp_delta": ps_context().dp_delta, + "dp_norm_clip": ps_context().dp_norm_clip, + "encrypt_type": ps_context().encrypt_type, + "root_first_ca_path": ps_context().root_first_ca_path, + "root_second_ca_path": ps_context().root_second_ca_path, + "pki_verify": ps_context().pki_verify, + "equip_crl_path": ps_context().equip_crl_path, + "replay_attack_time_diff": ps_context().replay_attack_time_diff, "enable_ssl": ps_context().enable_ssl, "client_password": ps_context().client_password, "server_password": ps_context().server_password, "scheduler_manage_port": ps_context().scheduler_manage_port, - "config_file_path": ps_context().config_file_path + "config_file_path": ps_context().config_file_path, + "http_url_prefix": ps_context().http_url_prefix } +_check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port", + "start_fl_job_threshold", "start_fl_job_time_window", "update_model_time_window", + "fl_iteration_num", "client_epoch_num", "client_batch_size", "cipher_time_window", + "reconstruct_secrets_threshold"] + +_check_non_negative_int_keys = ["worker_num"] + +_check_positive_float_keys = ["update_model_ratio", "client_learning_rate"] + +_check_port_keys = ["scheduler_port", "fl_server_port"] def _get_ps_mode_rank(): ps_rank = ps_context().ps_rank_id() diff --git a/mindspore/schema/cipher.fbs b/mindspore/schema/cipher.fbs index 4eaabb5f4d2..45009b1a14a 100644 --- a/mindspore/schema/cipher.fbs +++ b/mindspore/schema/cipher.fbs @@ -33,6 +33,10 @@ table ClientPublicKeys { s_pk: [ubyte]; pw_iv: [ubyte]; pw_salt: [ubyte]; + timestamp: string; + iteration: int; + signature: [ubyte]; + certificate_chain: [string]; } table ClientShare { @@ -50,6 +54,8 @@ table RequestExchangeKeys{ ind_iv:[ubyte]; pw_iv:[ubyte]; pw_salt:[ubyte]; + signature:[ubyte]; + certificate_chain:[string]; } table ResponseExchangeKeys{ @@ -63,6 +69,7 @@ table GetExchangeKeys{ fl_id:string; iteration:int; timestamp:string; + signature:[ubyte]; } table ReturnExchangeKeys{ @@ -72,11 +79,13 @@ table ReturnExchangeKeys{ next_req_time:string; } + table RequestShareSecrets{ fl_id:string; encrypted_shares:[ClientShare]; iteration:int; timestamp:string; + signature:[ubyte]; } table ResponseShareSecrets{ @@ -90,6 +99,7 @@ table GetShareSecrets{ fl_id:string; iteration:int; timestamp:string; + signature:[ubyte]; } table ReturnShareSecrets{ @@ -103,6 +113,7 @@ table GetClientList{ fl_id:string; iteration:int; timestamp:string; + signature:[ubyte]; } table ReturnClientList{ @@ -113,11 +124,47 @@ table ReturnClientList{ next_req_time:string; } +table ClientListSign { + fl_id:string; + signature:[ubyte]; +} + +table SendClientListSign{ + fl_id:string; + iteration:int; + timestamp:string; + signature:[ubyte]; + req_signature:[ubyte]; +} + +table ResponseClientListSign{ + retcode:int; + reason:string; + iteration:int; + next_req_time:string; +} + +table RequestAllClientListSign{ + fl_id:string; + iteration:int; + timestamp:string; + signature:[ubyte]; +} + +table ReturnAllClientListSign{ + retcode:int; + reason:string; + client_list_sign:[ClientListSign]; + iteration:int; + next_req_time:string; +} + table SendReconstructSecret{ fl_id:string; reconstruct_secret_shares:[ClientShare]; iteration:int; timestamp:string; + signature:[ubyte]; } table ReconstructSecret{ diff --git a/mindspore/schema/fl_job.fbs b/mindspore/schema/fl_job.fbs index e7a3d60a2b0..656f4dbe5d0 100644 --- a/mindspore/schema/fl_job.fbs +++ b/mindspore/schema/fl_job.fbs @@ -22,11 +22,8 @@ file_extension "fl"; enum ResponseCode: int { SUCCEED=200, - SucNotReady=201, - RepeatRequest=202, - SucNotMatch=204, + SucNotReady=201, OutOfTime=300, - NotSelected=301, RequestError=400, SystemError=500 } @@ -56,6 +53,11 @@ table RequestFLJob{ iteration:int; data_size:int; timestamp:string; + sign_data:[ubyte]; + key_attestation:string; + equip_cert:string; + equip_ca_cert:string; + root_cert:string; } table ResponseFLJob { retcode:int; @@ -88,7 +90,9 @@ table RequestUpdateModel{ iteration:int; feature_map:[FeatureMap]; timestamp:string; + signature:[ubyte]; } + table ResponseUpdateModel{ retcode:int; reason:string; diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 4f764cdbe9d..ee8eba0484e 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -195,7 +195,6 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gp list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/fl/server/kernel/apply_momentum_kernel.cc") -list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/fl/server/kernel/sgd_kernel.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc") diff --git a/tests/ut/cpp/ps/core/cluster_metadata_test.cc b/tests/ut/cpp/ps/core/cluster_metadata_test.cc index 1486a3c9157..7367b22edb2 100644 --- a/tests/ut/cpp/ps/core/cluster_metadata_test.cc +++ b/tests/ut/cpp/ps/core/cluster_metadata_test.cc @@ -43,7 +43,7 @@ TEST_F(TestClusterConfig, HeartbeatInterval) { common::SetEnv(kEnvSchedulerHost, host.c_str()); common::SetEnv(kEnvSchedulerPort, port.c_str()); PSContext::instance()->SetPSEnable(true); - EXPECT_EQ(300, PSContext::instance()->cluster_config().cluster_available_timeout); + EXPECT_EQ(900, PSContext::instance()->cluster_config().cluster_available_timeout); } } // namespace core } // namespace ps