added cert

This commit is contained in:
chendongsheng 2021-07-14 15:53:17 +08:00
parent 0e025d2210
commit dc3d6dc915
19 changed files with 206 additions and 50 deletions

View File

@ -108,6 +108,13 @@ constexpr char kRecoveryServerNum[] = "server_num";
constexpr char kRecoverySchedulerIp[] = "scheduler_ip";
constexpr char kRecoverySchedulerPort[] = "scheduler_port";
constexpr char kServerCertPath[] = "server_cert_path";
constexpr char kServerPassword[] = "server_password";
constexpr char kCrlPath[] = "crl_path";
constexpr char kClientCertPath[] = "client_cert_path";
constexpr char kClientPassword[] = "client_password";
constexpr char kCaCertPath[] = "ca_cert_path";
using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = uint64_t;

View File

@ -705,7 +705,11 @@ bool AbstractNode::WaitForDisconnect(const uint32_t &timeout) {
}
bool AbstractNode::InitClientToScheduler() {
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_);
if (config_ == nullptr) {
MS_LOG(WARNING) << "The config is empty.";
return false;
}
client_to_scheduler_ = std::make_shared<TcpClient>(scheduler_ip_, scheduler_port_, config_.get());
client_to_scheduler_->SetMessageCallback(
[&](const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data, size_t size) {
try {
@ -750,9 +754,12 @@ const std::shared_ptr<TcpClient> &AbstractNode::GetOrCreateTcpClient(const uint3
if (nodes_address_.find(std::make_pair(NodeRole::SERVER, rank_id)) == nodes_address_.end()) {
MS_LOG(EXCEPTION) << "Worker receive nodes info from scheduler failed!";
}
if (config_ == nullptr) {
MS_LOG(EXCEPTION) << "The config is empty.";
}
std::string ip = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].first;
uint16_t port = nodes_address_[std::make_pair(NodeRole::SERVER, rank_id)].second;
auto client = std::make_shared<TcpClient>(ip, port);
auto client = std::make_shared<TcpClient>(ip, port, config_.get());
client->SetMessageCallback([&](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data,
size_t size) {
switch (meta->cmd()) {

View File

@ -175,6 +175,21 @@ std::string CommUtil::ClusterStateToString(const ClusterState &state) {
MS_LOG(INFO) << "The cluster state:" << state;
return kClusterState.at(state);
}
std::string CommUtil::ParseConfig(const Configuration &config, const std::string &key) {
if (!config.IsInitialized()) {
MS_LOG(INFO) << "The config is not initialized.";
return "";
}
if (!const_cast<Configuration &>(config).Exists(key)) {
MS_LOG(INFO) << "The key:" << key << " is not exist.";
return "";
}
std::string path = config.Get(key, "");
return path;
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -57,6 +57,7 @@
#include "utils/log_adapter.h"
#include "ps/ps_context.h"
#include "utils/convert_utils_base.h"
#include "ps/core/configuration.h"
namespace mindspore {
namespace ps {
@ -100,6 +101,9 @@ class CommUtil {
// Convert cluster state to string when response the http request.
static std::string ClusterStateToString(const ClusterState &state);
// Parse the configuration file according to the key.
static std::string ParseConfig(const Configuration &config, const std::string &key);
private:
static std::random_device rd;
static std::mt19937_64 gen;

View File

@ -22,6 +22,8 @@
#include <openssl/err.h>
#include <openssl/evp.h>
#include <assert.h>
#include <openssl/pkcs12.h>
#include <openssl/bio.h>
#include <iostream>
#include <string>

View File

@ -37,13 +37,14 @@ event_base *TcpClient::event_base_ = nullptr;
std::mutex TcpClient::event_base_mutex_;
bool TcpClient::is_started_ = false;
TcpClient::TcpClient(const std::string &address, std::uint16_t port)
TcpClient::TcpClient(const std::string &address, std::uint16_t port, Configuration *config)
: event_timeout_(nullptr),
buffer_event_(nullptr),
server_address_(std::move(address)),
server_port_(port),
is_stop_(true),
is_connected_(false) {
is_connected_(false),
config_(config) {
message_handler_.SetCallback(
[this](const std::shared_ptr<MessageMeta> &meta, const Protos &protos, const void *data, size_t size) {
if (message_callback_) {
@ -109,35 +110,9 @@ void TcpClient::Init() {
if (!PSContext::instance()->enable_ssl()) {
buffer_event_ = bufferevent_socket_new(event_base_, -1, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
} else {
MS_LOG(INFO) << "Enable ssl support.";
SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx(false));
X509 *cert = SSLWrapper::GetInstance().ReadCertFromFile(kCertificateChain);
if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) {
MS_LOG(EXCEPTION) << "Verify cert time failed.";
if (!EstablishSSL()) {
MS_LOG(EXCEPTION) << "Establish SSL failed.";
}
if (!SSL_CTX_use_certificate_chain_file(SSLWrapper::GetInstance().GetSSLCtx(false), kCertificateChain)) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
}
if (!SSL_CTX_use_PrivateKey_file(SSLWrapper::GetInstance().GetSSLCtx(false), kPrivateKey, SSL_FILETYPE_PEM)) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
}
if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx(false))) {
MS_LOG(EXCEPTION) << "SSL check private key file failed!";
}
if (!SSL_CTX_load_verify_locations(SSLWrapper::GetInstance().GetSSLCtx(false), kCAcrt, nullptr)) {
MS_LOG(EXCEPTION) << "SSL load ca location failed!";
}
SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(false), SSL_OP_NO_SSLv2);
buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING,
BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
}
MS_EXCEPTION_IF_NULL(buffer_event_);
@ -229,6 +204,79 @@ void TcpClient::NotifyConnected() {
connection_cond_.notify_all();
}
bool TcpClient::EstablishSSL() {
MS_LOG(INFO) << "Enable ssl support.";
if (config_ == nullptr) {
MS_LOG(EXCEPTION) << "The config is empty.";
}
SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx(false));
// 1.Parse the client's certificate and the ciphertext of key.
std::string client_cert = kCertificateChain;
std::string path = CommUtil::ParseConfig(*config_, kClientCertPath);
if (!CommUtil::IsFileExists(path)) {
MS_LOG(WARNING) << "The key:" << kClientCertPath << "'s value is not exist.";
return false;
}
client_cert = path;
MS_LOG(INFO) << "The client cert path:" << client_cert;
// 2. Parse the client password.
std::string client_password = CommUtil::ParseConfig(*config_, kClientPassword);
if (!client_password.empty()) {
MS_LOG(WARNING) << "The key:" << kClientPassword << "'s value is empty.";
return false;
}
MS_LOG(INFO) << "The client password:" << client_password;
EVP_PKEY *pkey = nullptr;
X509 *cert = nullptr;
STACK_OF(X509) *ca_stack = nullptr;
BIO *bio = BIO_new_file(client_cert.c_str(), "rb");
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
BIO_free_all(bio);
PKCS12_parse(p12, client_password.c_str(), &pkey, &cert, &ca_stack);
PKCS12_free(p12);
if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) {
MS_LOG(EXCEPTION) << "Verify cert time failed.";
}
if (!SSL_CTX_use_certificate(SSLWrapper::GetInstance().GetSSLCtx(false), cert)) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
}
if (!SSL_CTX_use_PrivateKey(SSLWrapper::GetInstance().GetSSLCtx(false), pkey)) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
}
std::string client_ca = kCAcrt;
std::string ca_path = CommUtil::ParseConfig(*config_, kCaCertPath);
if (!CommUtil::IsFileExists(ca_path)) {
MS_LOG(WARNING) << "The key:" << kCaCertPath << "'s value is not exist.";
}
client_ca = ca_path;
MS_LOG(INFO) << "The ca cert path:" << client_ca;
if (!SSL_CTX_check_private_key(SSLWrapper::GetInstance().GetSSLCtx(false))) {
MS_LOG(EXCEPTION) << "SSL check private key file failed!";
}
if (!SSL_CTX_load_verify_locations(SSLWrapper::GetInstance().GetSSLCtx(false), client_ca.c_str(), nullptr)) {
MS_LOG(EXCEPTION) << "SSL load ca location failed!";
}
SSL_CTX_set_options(SSLWrapper::GetInstance().GetSSLCtx(false), SSL_OP_NO_SSLv2);
buffer_event_ = bufferevent_openssl_socket_new(event_base_, -1, ssl, BUFFEREVENT_SSL_CONNECTING,
BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
return true;
}
void TcpClient::EventCallback(struct bufferevent *bev, std::int16_t events, void *ptr) {
MS_EXCEPTION_IF_NULL(bev);
MS_EXCEPTION_IF_NULL(ptr);

View File

@ -38,6 +38,7 @@
#include "ps/constants.h"
#include "ps/ps_context.h"
#include "ps/core/communicator/tcp_message_handler.h"
#include "ps/core/file_configuration.h"
namespace mindspore {
namespace ps {
@ -52,7 +53,7 @@ class TcpClient {
std::function<void(const std::shared_ptr<MessageMeta> &, const Protos &, const void *, size_t size)>;
using OnTimer = std::function<void()>;
explicit TcpClient(const std::string &address, std::uint16_t port);
explicit TcpClient(const std::string &address, std::uint16_t port, Configuration *config);
virtual ~TcpClient();
std::string GetServerAddress() const;
@ -80,6 +81,7 @@ class TcpClient {
virtual void OnReadHandler(const void *buf, size_t num);
static void TimerCallback(evutil_socket_t fd, int16_t event, void *arg);
void NotifyConnected();
bool EstablishSSL();
private:
OnMessage message_callback_;
@ -104,6 +106,8 @@ class TcpClient {
std::uint16_t server_port_;
std::atomic<bool> is_stop_;
std::atomic<bool> is_connected_;
// The Configuration file
Configuration *config_;
};
} // namespace core
} // namespace ps

View File

@ -102,13 +102,14 @@ bool TcpConnection::SendMessage(const std::shared_ptr<MessageMeta> &meta, const
return res;
}
TcpServer::TcpServer(const std::string &address, std::uint16_t port)
TcpServer::TcpServer(const std::string &address, std::uint16_t port, Configuration *config)
: base_(nullptr),
signal_event_(nullptr),
listener_(nullptr),
server_address_(std::move(address)),
server_port_(port),
is_stop_(true) {}
is_stop_(true),
config_(config) {}
TcpServer::~TcpServer() {
if (signal_event_ != nullptr) {
@ -288,18 +289,48 @@ void TcpServer::ListenerCallback(struct evconnlistener *, evutil_socket_t fd, st
bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_THREADSAFE);
} else {
MS_LOG(INFO) << "Enable ssl support.";
if (server->config_ == nullptr) {
MS_LOG(EXCEPTION) << "The config is empty.";
}
SSL *ssl = SSL_new(SSLWrapper::GetInstance().GetSSLCtx());
X509 *cert = SSLWrapper::GetInstance().ReadCertFromFile(kCertificateChain);
// 1.Parse the server's certificate and the ciphertext of key.
std::string server_cert = kCertificateChain;
std::string path = CommUtil::ParseConfig(*(server->config_), kServerCertPath);
if (!CommUtil::IsFileExists(path)) {
MS_LOG(EXCEPTION) << "The key:" << kServerCertPath << "'s value is not exist.";
}
server_cert = path;
MS_LOG(INFO) << "The server cert path:" << server_cert;
// 2. Parse the server password.
std::string server_password = CommUtil::ParseConfig(*(server->config_), kServerPassword);
if (!server_password.empty()) {
MS_LOG(EXCEPTION) << "The key:" << kServerPassword << "'s value is empty.";
}
MS_LOG(INFO) << "The server password:" << server_password;
EVP_PKEY *pkey = nullptr;
X509 *cert = nullptr;
STACK_OF(X509) *ca_stack = nullptr;
BIO *bio = BIO_new_file(server_cert.c_str(), "rb");
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
BIO_free_all(bio);
PKCS12_parse(p12, server_password.c_str(), &pkey, &cert, &ca_stack);
PKCS12_free(p12);
if (!SSLWrapper::GetInstance().VerifyCertTime(cert)) {
MS_LOG(EXCEPTION) << "Verify cert time failed.";
}
if (!SSL_CTX_use_certificate_chain_file(SSLWrapper::GetInstance().GetSSLCtx(), kCertificateChain)) {
if (!SSL_CTX_use_certificate(SSLWrapper::GetInstance().GetSSLCtx(), cert)) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
}
if (!SSL_CTX_use_PrivateKey_file(SSLWrapper::GetInstance().GetSSLCtx(), kPrivateKey, SSL_FILETYPE_PEM)) {
if (!SSL_CTX_use_PrivateKey(SSLWrapper::GetInstance().GetSSLCtx(), pkey)) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
}

View File

@ -42,6 +42,7 @@
#include "ps/core/comm_util.h"
#include "ps/constants.h"
#include "ps/ps_context.h"
#include "ps/core/file_configuration.h"
namespace mindspore {
namespace ps {
@ -85,7 +86,7 @@ class TcpServer {
using OnTimerOnce = std::function<void(const TcpServer &)>;
using OnTimer = std::function<void()>;
TcpServer(const std::string &address, std::uint16_t port);
TcpServer(const std::string &address, std::uint16_t port, Configuration *config);
TcpServer(const TcpServer &server);
virtual ~TcpServer();
@ -140,6 +141,8 @@ class TcpServer {
OnServerReceiveMessage message_callback_;
OnTimerOnce on_timer_once_callback_;
OnTimer on_timer_callback_;
// The Configuration file
Configuration *config_;
};
} // namespace core
} // namespace ps

View File

@ -42,6 +42,9 @@ class Configuration {
// Initialize database connection or load config file.
virtual bool Initialize() = 0;
// Determine whether the initialization has been completed.
virtual bool IsInitialized() const = 0;
// Get configuration data from database or config file.
virtual std::string Get(const std::string &key, const std::string &defaultvalue) const = 0;

View File

@ -29,6 +29,7 @@ bool FileConfiguration::Initialize() {
std::ifstream json_file(file_path_);
json_file >> js;
json_file.close();
is_initialized_ = true;
} catch (nlohmann::json::exception &e) {
std::string illegal_exception = e.what();
MS_LOG(ERROR) << "Parse json file:" << file_path_ << " failed, the exception:" << illegal_exception;
@ -37,6 +38,8 @@ bool FileConfiguration::Initialize() {
return true;
}
bool FileConfiguration::IsInitialized() const { return is_initialized_.load(); }
std::string FileConfiguration::Get(const std::string &key, const std::string &defaultvalue) const {
if (!js.contains(key)) {
MS_LOG(WARNING) << "The key:" << key << " is not exist.";

View File

@ -47,11 +47,13 @@ namespace core {
//}
class FileConfiguration : public Configuration {
public:
explicit FileConfiguration(const std::string &path) : file_path_(path) {}
explicit FileConfiguration(const std::string &path) : file_path_(path), is_initialized_(false) {}
~FileConfiguration() = default;
bool Initialize() override;
bool IsInitialized() const override;
std::string Get(const std::string &key, const std::string &defaultvalue) const override;
void Put(const std::string &key, const std::string &value) override;
@ -63,6 +65,8 @@ class FileConfiguration : public Configuration {
std::string file_path_;
nlohmann::json js;
std::atomic<bool> is_initialized_;
};
} // namespace core
} // namespace ps

View File

@ -76,6 +76,10 @@ void SchedulerNode::ProcessHeartbeat(const std::shared_ptr<TcpServer> &server,
}
void SchedulerNode::Initialize() {
config_ = std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
if (!config_->Initialize()) {
MS_LOG(INFO) << "The config file is empty.";
}
InitCommandHandler();
CreateTcpServer();
is_already_stopped_ = false;
@ -101,7 +105,7 @@ void SchedulerNode::CreateTcpServer() {
std::string scheduler_host = PSContext::instance()->cluster_config().scheduler_host;
uint32_t scheduler_port = PSContext::instance()->cluster_config().scheduler_port;
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port);
server_ = std::make_shared<TcpServer>(scheduler_host, scheduler_port, config_.get());
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &, const void *data, size_t size) {
if (handlers_.count(meta->cmd()) == 0) {
@ -413,9 +417,12 @@ const std::shared_ptr<TcpClient> &SchedulerNode::GetOrCreateClient(const NodeInf
if (connected_nodes_.count(node_info.node_id_)) {
return connected_nodes_[node_info.node_id_];
} else {
if (config_ == nullptr) {
MS_LOG(EXCEPTION) << "The config is empty.";
}
std::string ip = node_info.ip_;
uint16_t port = node_info.port_;
auto client = std::make_shared<TcpClient>(ip, port);
auto client = std::make_shared<TcpClient>(ip, port, config_.get());
client->SetMessageCallback([&](std::shared_ptr<MessageMeta> meta, const Protos &protos, const void *data,
size_t size) { NotifyMessageArrival(meta); });
client->Init();

View File

@ -58,7 +58,7 @@ void ServerNode::CreateTcpServer() {
std::string interface;
std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0);
server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get());
server_->SetMessageCallback([&](const std::shared_ptr<TcpConnection> &conn, const std::shared_ptr<MessageMeta> &meta,
const Protos &protos, const void *data, size_t size) {
if (server_handler_.count(meta->cmd()) == 0) {

View File

@ -68,7 +68,7 @@ void WorkerNode::CreateTcpServer() {
std::string interface;
std::string server_ip;
CommUtil::GetAvailableInterfaceAndIP(&interface, &server_ip);
server_ = std::make_shared<TcpServer>(server_ip, 0);
server_ = std::make_shared<TcpServer>(server_ip, 0, config_.get());
server_->SetMessageCallback([&](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
const Protos &protos, const void *data, size_t size) {
if (server_handler_.count(meta->cmd()) == 0) {

View File

@ -2,5 +2,12 @@
"recovery": {
"storge_type": 1,
"storage_file_path": "recovery.json"
}
},
"server_cert_path": "server.crt",
"server_password": "server_password",
"crl_path": "server.crl",
"client_cert_path": "client.crt",
"client_password": "client_password",
"ca_cert_path": "ca.crt",
"Key_encrypt_decrypt_algorithm": ""
}

View File

@ -2,5 +2,12 @@
"recovery": {
"storge_type": 1,
"storage_file_path": "recovery.json"
}
},
"server_cert_path": "server.crt",
"server_password": "server_password",
"crl_path": "server.crl",
"client_cert_path": "client.crt",
"client_password": "client_password",
"ca_cert_path": "ca.crt",
"Key_encrypt_decrypt_algorithm": ""
}

View File

@ -28,7 +28,8 @@ class TestTcpClient : public UT::Common {
};
TEST_F(TestTcpClient, InitClientIPError) {
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000);
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
auto client = std::make_unique<TcpClient>("127.0.0.13543", 9000, config.get());
client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;
@ -41,7 +42,8 @@ TEST_F(TestTcpClient, InitClientIPError) {
}
TEST_F(TestTcpClient, InitClientPortErrorNoException) {
auto client = std::make_unique<TcpClient>("127.0.0.1", -1);
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
auto client = std::make_unique<TcpClient>("127.0.0.1", -1, config.get());
client->SetMessageCallback([&](std::shared_ptr<MessageMeta>, const Protos &, const void *data, size_t size) {
CommMessage message;

View File

@ -30,7 +30,8 @@ class TestTcpServer : public UT::Common {
virtual ~TestTcpServer() = default;
void SetUp() override {
server_ = std::make_unique<TcpServer>("127.0.0.1", 0);
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
server_ = std::make_unique<TcpServer>("127.0.0.1", 0, config.get());
std::unique_ptr<std::thread> http_server_thread_(nullptr);
http_server_thread_ = std::make_unique<std::thread>([=]() {
server_->SetMessageCallback([=](std::shared_ptr<TcpConnection> conn, std::shared_ptr<MessageMeta> meta,
@ -58,7 +59,8 @@ class TestTcpServer : public UT::Common {
};
TEST_F(TestTcpServer, ServerSendMessage) {
client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort());
std::unique_ptr<Configuration> config = std::make_unique<FileConfiguration>("");
client_ = std::make_unique<TcpClient>("127.0.0.1", server_->BoundPort(), config.get());
std::cout << server_->BoundPort() << std::endl;
std::unique_ptr<std::thread> http_client_thread(nullptr);
http_client_thread = std::make_unique<std::thread>([&]() {