From dc3d6dc9157ada435ea0ce0e035f997dadf68852 Mon Sep 17 00:00:00 2001 From: chendongsheng Date: Wed, 14 Jul 2021 15:53:17 +0800 Subject: [PATCH] added cert --- mindspore/ccsrc/ps/constants.h | 7 ++ mindspore/ccsrc/ps/core/abstract_node.cc | 11 +- mindspore/ccsrc/ps/core/comm_util.cc | 15 +++ mindspore/ccsrc/ps/core/comm_util.h | 4 + .../ccsrc/ps/core/communicator/ssl_wrapper.h | 2 + .../ccsrc/ps/core/communicator/tcp_client.cc | 108 +++++++++++++----- .../ccsrc/ps/core/communicator/tcp_client.h | 6 +- .../ccsrc/ps/core/communicator/tcp_server.cc | 41 ++++++- .../ccsrc/ps/core/communicator/tcp_server.h | 5 +- mindspore/ccsrc/ps/core/configuration.h | 3 + mindspore/ccsrc/ps/core/file_configuration.cc | 3 + mindspore/ccsrc/ps/core/file_configuration.h | 6 +- mindspore/ccsrc/ps/core/scheduler_node.cc | 11 +- mindspore/ccsrc/ps/core/server_node.cc | 2 +- mindspore/ccsrc/ps/core/worker_node.cc | 2 +- tests/st/fl/hybrid_lenet/config.json | 9 +- tests/st/fl/mobile/config.json | 9 +- tests/ut/cpp/ps/core/tcp_client_tests.cc | 6 +- tests/ut/cpp/ps/core/tcp_pb_server_test.cc | 6 +- 19 files changed, 206 insertions(+), 50 deletions(-) diff --git a/mindspore/ccsrc/ps/constants.h b/mindspore/ccsrc/ps/constants.h index 714fe26c3b5..b6f953d5747 100644 --- a/mindspore/ccsrc/ps/constants.h +++ b/mindspore/ccsrc/ps/constants.h @@ -108,6 +108,13 @@ constexpr char kRecoveryServerNum[] = "server_num"; constexpr char kRecoverySchedulerIp[] = "scheduler_ip"; constexpr char kRecoverySchedulerPort[] = "scheduler_port"; +constexpr char kServerCertPath[] = "server_cert_path"; +constexpr char kServerPassword[] = "server_password"; +constexpr char kCrlPath[] = "crl_path"; +constexpr char kClientCertPath[] = "client_cert_path"; +constexpr char kClientPassword[] = "client_password"; +constexpr char kCaCertPath[] = "ca_cert_path"; + using DataPtr = std::shared_ptr; using VectorPtr = std::shared_ptr>; using Key = uint64_t; diff --git a/mindspore/ccsrc/ps/core/abstract_node.cc b/mindspore/ccsrc/ps/core/abstract_node.cc index 2a4ed60a8ae..c5b5dc5bc76 100644 --- a/mindspore/ccsrc/ps/core/abstract_node.cc +++ b/mindspore/ccsrc/ps/core/abstract_node.cc @@ -705,7 +705,11 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) { } bool AbstractNode::InitClientToScheduler() { - client_to_scheduler_ = std::make_shared(scheduler_ip_, scheduler_port_); + if (config_ == nullptr) { + MS_LOG(WARNING) << "The config is empty."; + return false; + } + client_to_scheduler_ = std::make_shared(scheduler_ip_, scheduler_port_, config_.get()); client_to_scheduler_->SetMessageCallback( [&](const std::shared_ptr &meta, const Protos &, const void *data, size_t size) { try { @@ -750,9 +754,12 @@ const std::shared_ptr &AbstractNode::GetOrCreateTcpClient(const uint3 if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) { MS_LOG(EXCEPTION) << "Worker receive nodes info from scheduler failed!"; } + if (config_ == nullptr) { + MS_LOG(EXCEPTION) << "The config is empty."; + } std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first; uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second; - auto client = std::make_shared(ip, port); + auto client = std::make_shared(ip, port, config_.get()); client->SetMessageCallback([&](const std::shared_ptr &meta, const Protos &protos, const void *data, size_t size) { switch (meta->cmd()) { diff --git a/mindspore/ccsrc/ps/core/comm_util.cc b/mindspore/ccsrc/ps/core/comm_util.cc index f2a9b7efa5f..41ff0d0b429 100644 --- a/mindspore/ccsrc/ps/core/comm_util.cc +++ b/mindspore/ccsrc/ps/core/comm_util.cc @@ -175,6 +175,21 @@ std::string CommUtil::ClusterStateToString(const ClusterState &state) { MS_LOG(INFO) << "The cluster state:" << state; return kClusterState.at(state); } + +std::string CommUtil::ParseConfig(const Configuration &config, const std::string &key) { + if (!config.IsInitialized()) { + MS_LOG(INFO) << "The config is not initialized."; + return ""; + } + + if (!const_cast(config).Exists(key)) { + MS_LOG(INFO) << "The key:" << key << " is not exist."; + return ""; + } + + std::string path = config.Get(key, ""); + return path; +} } // namespace core } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/core/comm_util.h b/mindspore/ccsrc/ps/core/comm_util.h index d854328576e..2e127fe1d03 100644 --- a/mindspore/ccsrc/ps/core/comm_util.h +++ b/mindspore/ccsrc/ps/core/comm_util.h @@ -57,6 +57,7 @@ #include "utils/log_adapter.h" #include "ps/ps_context.h" #include "utils/convert_utils_base.h" +#include "ps/core/configuration.h" namespace mindspore { namespace ps { @@ -100,6 +101,9 @@ class CommUtil { // Convert cluster state to string when response the http request. static std::string ClusterStateToString(const ClusterState &state); + // Parse the configuration file according to the key. + static std::string ParseConfig(const Configuration &config, const std::string &key); + private: static std::random_device rd; static std::mt19937_64 gen; diff --git a/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h b/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h index 64551fcb3d6..e7870598e02 100644 --- a/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h +++ b/mindspore/ccsrc/ps/core/communicator/ssl_wrapper.h @@ -22,6 +22,8 @@ #include #include #include +#include +#include #include #include diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_client.cc b/mindspore/ccsrc/ps/core/communicator/tcp_client.cc index 8350c1510ff..bcb66f50916 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_client.cc +++ b/mindspore/ccsrc/ps/core/communicator/tcp_client.cc @@ -37,13 +37,14 @@ event_base *TcpClient::event_base_ = nullptr; std::mutex TcpClient::event_base_mutex_; bool TcpClient::is_started_ = false; -TcpClient::TcpClient(const std::string &address, std::uint16_t port) +TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *config) : event_timeout_(nullptr), buffer_event_(nullptr), server_address_(std::move(address)), server_port_(port), is_stop_(true), - is_connected_(false) { + is_connected_(false), + config_(config) { message_handler_.SetCallback( [this](const std::shared_ptr &meta, const Protos &protos, const void *data, size_t size) { if (message_callback_) { @@ -109,35 +110,9 @@ void TcpClient::Init() { if (!PSContext::instance()->enable_ssl()) { buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); } else { - MS_LOG(INFO) << "Enable ssl support."; - - SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx(false)); - - X509 *cert = SSLWrapper::GetInstance().ReadCertFromFile(kCertificateChain); - if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) { - MS_LOG(EXCEPTION) << "Verify cert time failed."; + if (!EstablishSSL()) { + MS_LOG(EXCEPTION) << "Establish SSL failed."; } - - if (!SSL_CTX_use_certificate_chain_file(SSLWrapper::GetInstance().GetSSLCtx(false), kCertificateChain)) { - MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; - } - - if (!SSL_CTX_use_PrivateKey_file(SSLWrapper::GetInstance().GetSSLCtx(false), kPrivateKey, SSL_FILETYPE_PEM)) { - MS_LOG(EXCEPTION) << "SSL use private key file failed!"; - } - - if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx(false))) { - MS_LOG(EXCEPTION) << "SSL check private key file failed!"; - } - - if (!SSL_CTX_load_verify_locations(SSLWrapper::GetInstance().GetSSLCtx(false), kCAcrt, nullptr)) { - MS_LOG(EXCEPTION) << "SSL load ca location failed!"; - } - - SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(false), SSL_OP_NO_SSLv2); - - buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING, - BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); } MS_EXCEPTION_IF_NULL(buffer_event_); @@ -229,6 +204,79 @@ void TcpClient::NotifyConnected() { connection_cond_.notify_all(); } +bool TcpClient::EstablishSSL() { + MS_LOG(INFO) << "Enable ssl support."; + + if (config_ == nullptr) { + MS_LOG(EXCEPTION) << "The config is empty."; + } + + SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx(false)); + + // 1.Parse the client's certificate and the ciphertext of key. + std::string client_cert = kCertificateChain; + std::string path = CommUtil::ParseConfig(*config_, kClientCertPath); + if (!CommUtil::IsFileExists(path)) { + MS_LOG(WARNING) << "The key:" << kClientCertPath << "'s value is not exist."; + return false; + } + client_cert = path; + + MS_LOG(INFO) << "The client cert path:" << client_cert; + + // 2. Parse the client password. + std::string client_password = CommUtil::ParseConfig(*config_, kClientPassword); + if (!client_password.empty()) { + MS_LOG(WARNING) << "The key:" << kClientPassword << "'s value is empty."; + return false; + } + + MS_LOG(INFO) << "The client password:" << client_password; + + EVP_PKEY *pkey = nullptr; + X509 *cert = nullptr; + STACK_OF(X509) *ca_stack = nullptr; + BIO *bio = BIO_new_file(client_cert.c_str(), "rb"); + PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr); + BIO_free_all(bio); + PKCS12_parse(p12, client_password.c_str(), &pkey, &cert, &ca_stack); + PKCS12_free(p12); + + if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) { + MS_LOG(EXCEPTION) << "Verify cert time failed."; + } + + if (!SSL_CTX_use_certificate(SSLWrapper::GetInstance().GetSSLCtx(false), cert)) { + MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; + } + + if (!SSL_CTX_use_PrivateKey(SSLWrapper::GetInstance().GetSSLCtx(false), pkey)) { + MS_LOG(EXCEPTION) << "SSL use private key file failed!"; + } + + std::string client_ca = kCAcrt; + std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath); + if (!CommUtil::IsFileExists(ca_path)) { + MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist."; + } + client_ca = ca_path; + MS_LOG(INFO) << "The ca cert path:" << client_ca; + + if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx(false))) { + MS_LOG(EXCEPTION) << "SSL check private key file failed!"; + } + + if (!SSL_CTX_load_verify_locations(SSLWrapper::GetInstance().GetSSLCtx(false), client_ca.c_str(), nullptr)) { + MS_LOG(EXCEPTION) << "SSL load ca location failed!"; + } + + SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(false), SSL_OP_NO_SSLv2); + + buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING, + BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); + return true; +} + void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) { MS_EXCEPTION_IF_NULL(bev); MS_EXCEPTION_IF_NULL(ptr); diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_client.h b/mindspore/ccsrc/ps/core/communicator/tcp_client.h index eb18f1e63c7..c73b1e7666b 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_client.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_client.h @@ -38,6 +38,7 @@ #include "ps/constants.h" #include "ps/ps_context.h" #include "ps/core/communicator/tcp_message_handler.h" +#include "ps/core/file_configuration.h" namespace mindspore { namespace ps { @@ -52,7 +53,7 @@ class TcpClient { std::function &, const Protos &, const void *, size_t size)>; using OnTimer = std::function; - explicit TcpClient(const std::string &address, std::uint16_t port); + explicit TcpClient(const std::string &address, std::uint16_t port, Configuration *config); virtual ~TcpClient(); std::string GetServerAddress() const; @@ -80,6 +81,7 @@ class TcpClient { virtual void OnReadHandler(const void *buf, size_t num); static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg); void NotifyConnected(); + bool EstablishSSL(); private: OnMessage message_callback_; @@ -104,6 +106,8 @@ class TcpClient { std::uint16_t server_port_; std::atomic is_stop_; std::atomic is_connected_; + // The Configuration file + Configuration *config_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_server.cc b/mindspore/ccsrc/ps/core/communicator/tcp_server.cc index 39bb36e6c9c..5acb741c5ee 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_server.cc +++ b/mindspore/ccsrc/ps/core/communicator/tcp_server.cc @@ -102,13 +102,14 @@ bool TcpConnection::SendMessage(const std::shared_ptr &meta, const return res; } -TcpServer::TcpServer(const std::string &address, std::uint16_t port) +TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *config) : base_(nullptr), signal_event_(nullptr), listener_(nullptr), server_address_(std::move(address)), server_port_(port), - is_stop_(true) {} + is_stop_(true), + config_(config) {} TcpServer::~TcpServer() { if (signal_event_ != nullptr) { @@ -288,18 +289,48 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE); } else { MS_LOG(INFO) << "Enable ssl support."; + if (server->config_ == nullptr) { + MS_LOG(EXCEPTION) << "The config is empty."; + } + SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx()); - X509 *cert = SSLWrapper::GetInstance().ReadCertFromFile(kCertificateChain); + // 1.Parse the server's certificate and the ciphertext of key. + std::string server_cert = kCertificateChain; + std::string path = CommUtil::ParseConfig(*(server->config_), kServerCertPath); + if (!CommUtil::IsFileExists(path)) { + MS_LOG(EXCEPTION) << "The key:" << kServerCertPath << "'s value is not exist."; + } + server_cert = path; + + MS_LOG(INFO) << "The server cert path:" << server_cert; + + // 2. Parse the server password. + std::string server_password = CommUtil::ParseConfig(*(server->config_), kServerPassword); + if (!server_password.empty()) { + MS_LOG(EXCEPTION) << "The key:" << kServerPassword << "'s value is empty."; + } + + MS_LOG(INFO) << "The server password:" << server_password; + + EVP_PKEY *pkey = nullptr; + X509 *cert = nullptr; + STACK_OF(X509) *ca_stack = nullptr; + BIO *bio = BIO_new_file(server_cert.c_str(), "rb"); + PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr); + BIO_free_all(bio); + PKCS12_parse(p12, server_password.c_str(), &pkey, &cert, &ca_stack); + PKCS12_free(p12); + if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) { MS_LOG(EXCEPTION) << "Verify cert time failed."; } - if (!SSL_CTX_use_certificate_chain_file(SSLWrapper::GetInstance().GetSSLCtx(), kCertificateChain)) { + if (!SSL_CTX_use_certificate(SSLWrapper::GetInstance().GetSSLCtx(), cert)) { MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!"; } - if (!SSL_CTX_use_PrivateKey_file(SSLWrapper::GetInstance().GetSSLCtx(), kPrivateKey, SSL_FILETYPE_PEM)) { + if (!SSL_CTX_use_PrivateKey(SSLWrapper::GetInstance().GetSSLCtx(), pkey)) { MS_LOG(EXCEPTION) << "SSL use private key file failed!"; } diff --git a/mindspore/ccsrc/ps/core/communicator/tcp_server.h b/mindspore/ccsrc/ps/core/communicator/tcp_server.h index 3f5c95afeb0..05bfd71640c 100644 --- a/mindspore/ccsrc/ps/core/communicator/tcp_server.h +++ b/mindspore/ccsrc/ps/core/communicator/tcp_server.h @@ -42,6 +42,7 @@ #include "ps/core/comm_util.h" #include "ps/constants.h" #include "ps/ps_context.h" +#include "ps/core/file_configuration.h" namespace mindspore { namespace ps { @@ -85,7 +86,7 @@ class TcpServer { using OnTimerOnce = std::function; using OnTimer = std::function; - TcpServer(const std::string &address, std::uint16_t port); + TcpServer(const std::string &address, std::uint16_t port, Configuration *config); TcpServer(const TcpServer &server); virtual ~TcpServer(); @@ -140,6 +141,8 @@ class TcpServer { OnServerReceiveMessage message_callback_; OnTimerOnce on_timer_once_callback_; OnTimer on_timer_callback_; + // The Configuration file + Configuration *config_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/configuration.h b/mindspore/ccsrc/ps/core/configuration.h index dcb55543ccd..9991be1fc57 100644 --- a/mindspore/ccsrc/ps/core/configuration.h +++ b/mindspore/ccsrc/ps/core/configuration.h @@ -42,6 +42,9 @@ class Configuration { // Initialize database connection or load config file. virtual bool Initialize() = 0; + // Determine whether the initialization has been completed. + virtual bool IsInitialized() const = 0; + // Get configuration data from database or config file. virtual std::string Get(const std::string &key, const std::string &defaultvalue) const = 0; diff --git a/mindspore/ccsrc/ps/core/file_configuration.cc b/mindspore/ccsrc/ps/core/file_configuration.cc index 1a01c7fa487..eaaf4104067 100644 --- a/mindspore/ccsrc/ps/core/file_configuration.cc +++ b/mindspore/ccsrc/ps/core/file_configuration.cc @@ -29,6 +29,7 @@ bool FileConfiguration::Initialize() { std::ifstream json_file(file_path_); json_file >> js; json_file.close(); + is_initialized_ = true; } catch (nlohmann::json::exception &e) { std::string illegal_exception = e.what(); MS_LOG(ERROR) << "Parse json file:" << file_path_ << " failed, the exception:" << illegal_exception; @@ -37,6 +38,8 @@ bool FileConfiguration::Initialize() { return true; } +bool FileConfiguration::IsInitialized() const { return is_initialized_.load(); } + std::string FileConfiguration::Get(const std::string &key, const std::string &defaultvalue) const { if (!js.contains(key)) { MS_LOG(WARNING) << "The key:" << key << " is not exist."; diff --git a/mindspore/ccsrc/ps/core/file_configuration.h b/mindspore/ccsrc/ps/core/file_configuration.h index 165bf424c7a..d5558967f32 100644 --- a/mindspore/ccsrc/ps/core/file_configuration.h +++ b/mindspore/ccsrc/ps/core/file_configuration.h @@ -47,11 +47,13 @@ namespace core { //} class FileConfiguration : public Configuration { public: - explicit FileConfiguration(const std::string &path) : file_path_(path) {} + explicit FileConfiguration(const std::string &path) : file_path_(path), is_initialized_(false) {} ~FileConfiguration() = default; bool Initialize() override; + bool IsInitialized() const override; + std::string Get(const std::string &key, const std::string &defaultvalue) const override; void Put(const std::string &key, const std::string &value) override; @@ -63,6 +65,8 @@ class FileConfiguration : public Configuration { std::string file_path_; nlohmann::json js; + + std::atomic is_initialized_; }; } // namespace core } // namespace ps diff --git a/mindspore/ccsrc/ps/core/scheduler_node.cc b/mindspore/ccsrc/ps/core/scheduler_node.cc index d9c04927e10..b79e372afea 100644 --- a/mindspore/ccsrc/ps/core/scheduler_node.cc +++ b/mindspore/ccsrc/ps/core/scheduler_node.cc @@ -76,6 +76,10 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr &server, } void SchedulerNode::Initialize() { + config_ = std::make_unique(PSContext::instance()->config_file_path()); + if (!config_->Initialize()) { + MS_LOG(INFO) << "The config file is empty."; + } InitCommandHandler(); CreateTcpServer(); is_already_stopped_ = false; @@ -101,7 +105,7 @@ void SchedulerNode::CreateTcpServer() { std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host; uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port; - server_ = std::make_shared(scheduler_host, scheduler_port); + server_ = std::make_shared(scheduler_host, scheduler_port, config_.get()); server_->SetMessageCallback([&](const std::shared_ptr &conn, const std::shared_ptr &meta, const Protos &, const void *data, size_t size) { if (handlers_.count(meta->cmd()) == 0) { @@ -413,9 +417,12 @@ const std::shared_ptr &SchedulerNode::GetOrCreateClient(const NodeInf if (connected_nodes_.count(node_info.node_id_)) { return connected_nodes_[node_info.node_id_]; } else { + if (config_ == nullptr) { + MS_LOG(EXCEPTION) << "The config is empty."; + } std::string ip = node_info.ip_; uint16_t port = node_info.port_; - auto client = std::make_shared(ip, port); + auto client = std::make_shared(ip, port, config_.get()); client->SetMessageCallback([&](std::shared_ptr meta, const Protos &protos, const void *data, size_t size) { NotifyMessageArrival(meta); }); client->Init(); diff --git a/mindspore/ccsrc/ps/core/server_node.cc b/mindspore/ccsrc/ps/core/server_node.cc index 290aaa82dcc..86a619ab627 100644 --- a/mindspore/ccsrc/ps/core/server_node.cc +++ b/mindspore/ccsrc/ps/core/server_node.cc @@ -58,7 +58,7 @@ void ServerNode::CreateTcpServer() { std::string interface; std::string server_ip; CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); - server_ = std::make_shared(server_ip, 0); + server_ = std::make_shared(server_ip, 0, config_.get()); server_->SetMessageCallback([&](const std::shared_ptr &conn, const std::shared_ptr &meta, const Protos &protos, const void *data, size_t size) { if (server_handler_.count(meta->cmd()) == 0) { diff --git a/mindspore/ccsrc/ps/core/worker_node.cc b/mindspore/ccsrc/ps/core/worker_node.cc index 43fa971cfc6..e0d891eee3f 100644 --- a/mindspore/ccsrc/ps/core/worker_node.cc +++ b/mindspore/ccsrc/ps/core/worker_node.cc @@ -68,7 +68,7 @@ void WorkerNode::CreateTcpServer() { std::string interface; std::string server_ip; CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip); - server_ = std::make_shared(server_ip, 0); + server_ = std::make_shared(server_ip, 0, config_.get()); server_->SetMessageCallback([&](std::shared_ptr conn, std::shared_ptr meta, const Protos &protos, const void *data, size_t size) { if (server_handler_.count(meta->cmd()) == 0) { diff --git a/tests/st/fl/hybrid_lenet/config.json b/tests/st/fl/hybrid_lenet/config.json index 37ac6edfb25..d5d1a93a6a9 100644 --- a/tests/st/fl/hybrid_lenet/config.json +++ b/tests/st/fl/hybrid_lenet/config.json @@ -2,5 +2,12 @@ "recovery": { "storge_type": 1, "storage_file_path": "recovery.json" - } + }, + "server_cert_path": "server.crt", + "server_password": "server_password", + "crl_path": "server.crl", + "client_cert_path": "client.crt", + "client_password": "client_password", + "ca_cert_path": "ca.crt", + "Key_encrypt_decrypt_algorithm": "" } \ No newline at end of file diff --git a/tests/st/fl/mobile/config.json b/tests/st/fl/mobile/config.json index 37ac6edfb25..d5d1a93a6a9 100644 --- a/tests/st/fl/mobile/config.json +++ b/tests/st/fl/mobile/config.json @@ -2,5 +2,12 @@ "recovery": { "storge_type": 1, "storage_file_path": "recovery.json" - } + }, + "server_cert_path": "server.crt", + "server_password": "server_password", + "crl_path": "server.crl", + "client_cert_path": "client.crt", + "client_password": "client_password", + "ca_cert_path": "ca.crt", + "Key_encrypt_decrypt_algorithm": "" } \ No newline at end of file diff --git a/tests/ut/cpp/ps/core/tcp_client_tests.cc b/tests/ut/cpp/ps/core/tcp_client_tests.cc index e8c9603275f..62d0c82d1b5 100644 --- a/tests/ut/cpp/ps/core/tcp_client_tests.cc +++ b/tests/ut/cpp/ps/core/tcp_client_tests.cc @@ -28,7 +28,8 @@ class TestTcpClient : public UT::Common { }; TEST_F(TestTcpClient, InitClientIPError) { - auto client = std::make_unique("127.0.0.13543", 9000); + std::unique_ptr config = std::make_unique(""); + auto client = std::make_unique("127.0.0.13543", 9000, config.get()); client->SetMessageCallback([&](std::shared_ptr, const Protos &, const void *data, size_t size) { CommMessage message; @@ -41,7 +42,8 @@ TEST_F(TestTcpClient, InitClientIPError) { } TEST_F(TestTcpClient, InitClientPortErrorNoException) { - auto client = std::make_unique("127.0.0.1", -1); + std::unique_ptr config = std::make_unique(""); + auto client = std::make_unique("127.0.0.1", -1, config.get()); client->SetMessageCallback([&](std::shared_ptr, const Protos &, const void *data, size_t size) { CommMessage message; diff --git a/tests/ut/cpp/ps/core/tcp_pb_server_test.cc b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc index 9daa85fd1d9..cf36a02093e 100644 --- a/tests/ut/cpp/ps/core/tcp_pb_server_test.cc +++ b/tests/ut/cpp/ps/core/tcp_pb_server_test.cc @@ -30,7 +30,8 @@ class TestTcpServer : public UT::Common { virtual ~TestTcpServer() = default; void SetUp() override { - server_ = std::make_unique("127.0.0.1", 0); + std::unique_ptr config = std::make_unique(""); + server_ = std::make_unique("127.0.0.1", 0, config.get()); std::unique_ptr http_server_thread_(nullptr); http_server_thread_ = std::make_unique([=]() { server_->SetMessageCallback([=](std::shared_ptr conn, std::shared_ptr meta, @@ -58,7 +59,8 @@ class TestTcpServer : public UT::Common { }; TEST_F(TestTcpServer, ServerSendMessage) { - client_ = std::make_unique("127.0.0.1", server_->BoundPort()); + std::unique_ptr config = std::make_unique(""); + client_ = std::make_unique("127.0.0.1", server_->BoundPort(), config.get()); std::cout << server_->BoundPort() << std::endl; std::unique_ptr http_client_thread(nullptr); http_client_thread = std::make_unique([&]() {