added check certificate

This commit is contained in:
chendongsheng 2021-08-12 11:02:32 +08:00
parent 671b05e034
commit b24f2e75d4
8 changed files with 398 additions and 6 deletions

View File

@ -24,6 +24,7 @@ if(NOT ENABLE_CPU OR WIN32)
list(REMOVE_ITEM _PS_SRC_FILES "parameter_server.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_request_handler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/ssl_wrapper.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/ssl_http.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/leader_scaler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/follower_scaler.cc")
list(REMOVE_ITEM _PS_SRC_FILES "core/file_configuration.cc")

View File

@ -133,6 +133,59 @@ constexpr char kClientCertPath[] = "client_cert_path";
constexpr char kClientPassword[] = "client_password";
constexpr char kCaCertPath[] = "ca_cert_path";
constexpr char kCipherList[] = "cipher_list";
constexpr char kCertCheckInterval[] = "cert_check_interval_in_hour";
// 7 * 24
constexpr int64_t kCertCheckIntervalInHour = 168;
constexpr char kCertExpireWarningTime[] = "cert_expire_warning_time_in_day";
// 90
constexpr int64_t kCertExpireWarningTimeInDay = 90;
constexpr char kConnectionNum[] = "connection_num";
constexpr int64_t kConnectionNumDefault = 10000;
constexpr char kLocalIp[] = "127.0.0.1";
constexpr int64_t kJanuary = 1;
constexpr int64_t kSeventyYear = 70;
constexpr int64_t kHundredYear = 100;
constexpr int64_t kThousandYear = 1000;
constexpr int64_t kBaseYear = 1900;
constexpr int64_t kMinWarningTime = 7;
constexpr int64_t kMaxWarningTime = 180;
constexpr char kServerCert[] = "server.p12";
constexpr char kClientCert[] = "client.p12";
constexpr char kCaCert[] = "ca.crt";
constexpr char kColon = ':';
const std::map<std::string, size_t> kCiphers = {{"ECDHE-RSA-AES128-GCM-SHA256", 0},
{"ECDHE-ECDSA-AES128-GCM-SHA256", 1},
{"ECDHE-RSA-AES256-GCM-SHA384", 2},
{"ECDHE-ECDSA-AES256-GCM-SHA384", 3},
{"DHE-RSA-AES128-GCM-SHA256", 4},
{"DHE-DSS-AES128-GCM-SHA256", 5},
{"ECDHE-RSA-AES128-SHA256", 6},
{"ECDHE-ECDSA-AES128-SHA256", 7},
{"ECDHE-RSA-AES128-SHA", 8},
{"ECDHE-ECDSA-AES128-SHA", 9},
{"ECDHE-RSA-AES256-SHA384", 10},
{"ECDHE-ECDSA-AES256-SHA384", 11},
{"ECDHE-RSA-AES256-SHA", 12},
{"ECDHE-ECDSA-AES256-SHA", 13},
{"DHE-RSA-AES128-SHA256", 14},
{"DHE-RSA-AES128-SHA", 15},
{"DHE-DSS-AES128-SHA256", 16},
{"DHE-RSA-AES256-SHA256", 17},
{"DHE-DSS-AES256-SHA", 18},
{"DHE-RSA-AES256-SHA", 19},
{"!aNULL", 20},
{"!eNULL", 21},
{"!EXPORT", 22},
{"!DES", 23},
{"!RC4", 24},
{"!3DES", 25},
{"!MD5", 26},
{"!PSK", 27},
{"kEDH+AESGCM", 28}};
using DataPtr = std::shared_ptr<unsigned char[]>;
using VectorPtr = std::shared_ptr<std::vector<unsigned char>>;
using Key = uint64_t;

View File

@ -64,7 +64,9 @@ void CommUtil::GetAvailableInterfaceAndIP(std::string *interface, std::string *i
interface->clear();
ip->clear();
getifaddrs(&if_address);
if (getifaddrs(&if_address) == -1) {
MS_LOG(WARNING) << "Get ifaddrs failed.";
}
for (ifa = if_address; ifa != nullptr; ifa = ifa->ifa_next) {
if (ifa->ifa_addr == nullptr) {
continue;
@ -146,6 +148,7 @@ bool CommUtil::Retry(const std::function<bool()> &func, size_t max_attempts, siz
}
void CommUtil::LogCallback(int severity, const char *msg) {
MS_EXCEPTION_IF_NULL(msg);
switch (severity) {
case EVENT_LOG_MSG:
MS_LOG(INFO) << kLibeventLogPrefix << msg;
@ -173,7 +176,11 @@ bool CommUtil::IsFileExists(const std::string &file) {
std::string CommUtil::ClusterStateToString(const ClusterState &state) {
MS_LOG(INFO) << "The cluster state:" << state;
return kClusterState.at(state);
if (state < SizeToInt(kClusterState.size())) {
return kClusterState.at(state);
} else {
return "";
}
}
std::string CommUtil::ParseConfig(const Configuration &config, const std::string &key) {
@ -190,6 +197,145 @@ std::string CommUtil::ParseConfig(const Configuration &config, const std::string
std::string path = config.GetString(key, "");
return path;
}
bool CommUtil::VerifyCertTime(const X509 *cert, int64_t time) {
MS_EXCEPTION_IF_NULL(cert);
ASN1_TIME *start = X509_getm_notBefore(cert);
ASN1_TIME *end = X509_getm_notAfter(cert);
MS_EXCEPTION_IF_NULL(start);
MS_EXCEPTION_IF_NULL(end);
int day = 0;
int sec = 0;
if (!ASN1_TIME_diff(&day, &sec, start, NULL)) {
MS_LOG(WARNING) << "ASN1 time diff failed.";
return false;
}
if (day < 0 || sec < 0) {
MS_LOG(WARNING) << "Cert start time is later than now time.";
return false;
}
day = 0;
sec = 0;
if (!ASN1_TIME_diff(&day, &sec, NULL, end)) {
MS_LOG(WARNING) << "ASN1 time diff failed.";
return false;
}
int64_t interval = kCertExpireWarningTimeInDay;
if (time > 0) {
interval = time;
}
if (day < LongToInt(interval) && day >= 0) {
MS_LOG(WARNING) << "The certificate will expire in " << day << " days and " << sec << " seconds.";
} else if (day < 0 || sec < 0) {
MS_LOG(WARNING) << "The certificate has expired.";
return false;
}
return true;
}
bool CommUtil::VerifyCRL(const X509 *cert, const std::string &crl_path) {
MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
BIO *bio = BIO_new_file(crl_path.c_str(), "r");
MS_ERROR_IF_NULL_W_RET_VAL(bio, false);
X509_CRL *root_crl = PEM_read_bio_X509_CRL(bio, nullptr, nullptr, nullptr);
MS_ERROR_IF_NULL_W_RET_VAL(root_crl, false);
EVP_PKEY *evp_pkey = X509_get_pubkey(const_cast<X509 *>(cert));
MS_ERROR_IF_NULL_W_RET_VAL(evp_pkey, false);
int ret = X509_CRL_verify(root_crl, evp_pkey);
BIO_free_all(bio);
if (ret == 1) {
MS_LOG(WARNING) << "Equip cert in root crl, verify failed";
return false;
}
MS_LOG(INFO) << "VerifyCRL success.";
return true;
}
bool CommUtil::VerifyCommonName(const X509 *cert, const std::string &ca_path) {
MS_ERROR_IF_NULL_W_RET_VAL(cert, false);
X509 *cert_temp = const_cast<X509 *>(cert);
char subject_cn[256] = "";
char issuer_cn[256] = "";
X509_NAME *subject_name = X509_get_subject_name(cert_temp);
X509_NAME *issuer_name = X509_get_issuer_name(cert_temp);
MS_ERROR_IF_NULL_W_RET_VAL(subject_name, false);
MS_ERROR_IF_NULL_W_RET_VAL(issuer_name, false);
if (!X509_NAME_get_text_by_NID(subject_name, NID_commonName, subject_cn, sizeof(subject_cn))) {
MS_LOG(WARNING) << "Get text by nid failed.";
return false;
}
if (!X509_NAME_get_text_by_NID(issuer_name, NID_commonName, issuer_cn, sizeof(issuer_cn))) {
MS_LOG(WARNING) << "Get text by nid failed.";
return false;
}
MS_LOG(INFO) << "the subject:" << subject_cn << ", the issuer:" << issuer_cn;
BIO *ca_bio = BIO_new_file(ca_path.c_str(), "r");
MS_EXCEPTION_IF_NULL(ca_bio);
X509 *ca_cert = PEM_read_bio_X509(ca_bio, nullptr, nullptr, nullptr);
MS_EXCEPTION_IF_NULL(ca_cert);
char ca_subject_cn[256] = "";
char ca_issuer_cn[256] = "";
X509_NAME *ca_subject_name = X509_get_subject_name(ca_cert);
X509_NAME *ca_issuer_name = X509_get_issuer_name(ca_cert);
MS_ERROR_IF_NULL_W_RET_VAL(ca_subject_name, false);
MS_ERROR_IF_NULL_W_RET_VAL(ca_issuer_name, false);
if (!X509_NAME_get_text_by_NID(ca_subject_name, NID_commonName, ca_subject_cn, sizeof(subject_cn))) {
MS_LOG(WARNING) << "Get text by nid failed.";
return false;
}
if (!X509_NAME_get_text_by_NID(ca_issuer_name, NID_commonName, ca_issuer_cn, sizeof(issuer_cn))) {
MS_LOG(WARNING) << "Get text by nid failed.";
return false;
}
MS_LOG(INFO) << "the subject:" << ca_subject_cn << ", the issuer:" << ca_issuer_cn;
BIO_free_all(ca_bio);
if (strcmp(issuer_cn, ca_subject_cn) != 0) {
return false;
}
return true;
}
std::vector<std::string> CommUtil::Split(const std::string &s, char delim) {
std::vector<std::string> res;
std::stringstream ss(s);
std::string item;
while (getline(ss, item, delim)) {
res.push_back(item);
}
return res;
}
bool CommUtil::VerifyCipherList(const std::vector<std::string> &list) {
for (auto &item : list) {
if (!kCiphers.count(item)) {
MS_LOG(WARNING) << "The ciphter:" << item << " is not supported.";
return false;
}
}
return true;
}
void CommUtil::InitOpenSSLEnv() {
if (!SSL_library_init()) {
MS_LOG(EXCEPTION) << "SSL_library_init failed.";
}
if (!ERR_load_crypto_strings()) {
MS_LOG(EXCEPTION) << "ERR_load_crypto_strings failed.";
}
if (!SSL_load_error_strings()) {
MS_LOG(EXCEPTION) << "SSL_load_error_strings failed.";
}
if (!OpenSSL_add_all_algorithms()) {
MS_LOG(EXCEPTION) << "OpenSSL_add_all_algorithms failed.";
}
}
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -37,6 +37,14 @@
#include <event2/listener.h>
#include <event2/util.h>
#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <assert.h>
#include <openssl/pkcs12.h>
#include <openssl/bio.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
@ -49,6 +57,7 @@
#include <fstream>
#include <iostream>
#include <vector>
#include <algorithm>
#include "proto/comm.pb.h"
#include "proto/ps.pb.h"
@ -104,6 +113,18 @@ class CommUtil {
// Parse the configuration file according to the key.
static std::string ParseConfig(const Configuration &config, const std::string &key);
// verify valid of certificate time
static bool VerifyCertTime(const X509 *cert, int64_t time = 0);
// verify valid of equip certificate with CRL
static bool VerifyCRL(const X509 *cert, const std::string &crl_path);
// Check the common name of the certificate
static bool VerifyCommonName(const X509 *cert, const std::string &ca_path);
// The string is divided according to delim
static std::vector<std::string> Split(const std::string &s, char delim);
// Check the cipher list of the certificate
static bool VerifyCipherList(const std::vector<std::string> &list);
static void InitOpenSSLEnv();
private:
static std::random_device rd;
static std::mt19937_64 gen;

View File

@ -0,0 +1,111 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/core/communicator/ssl_http.h"
#include <sys/time.h>
#include <openssl/pem.h>
#include <openssl/sha.h>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <vector>
#include <iomanip>
#include <sstream>
namespace mindspore {
namespace ps {
namespace core {
SSLHTTP::SSLHTTP() : ssl_ctx_(nullptr) { InitSSL(); }
SSLHTTP::~SSLHTTP() { CleanSSL(); }
void SSLHTTP::InitSSL() {
CommUtil::InitOpenSSLEnv();
ssl_ctx_ = SSL_CTX_new(SSLv23_server_method());
if (!ssl_ctx_) {
MS_LOG(EXCEPTION) << "SSL_CTX_new failed";
}
X509_STORE *store = SSL_CTX_get_cert_store(ssl_ctx_);
MS_EXCEPTION_IF_NULL(store);
if (X509_STORE_set_default_paths(store) != 1) {
MS_LOG(EXCEPTION) << "X509_STORE_set_default_paths failed";
}
std::unique_ptr<Configuration> config_ =
std::make_unique<FileConfiguration>(PSContext::instance()->config_file_path());
MS_EXCEPTION_IF_NULL(config_);
if (!config_->Initialize()) {
MS_LOG(EXCEPTION) << "The config file is empty.";
}
// 1.Parse the server's certificate and the ciphertext of key.
std::string server_cert = kCertificateChain;
std::string path = CommUtil::ParseConfig(*(config_), kServerCertPath);
if (!CommUtil::IsFileExists(path)) {
MS_LOG(EXCEPTION) << "The key:" << kServerCertPath << "'s value is not exist.";
}
server_cert = path;
// 2. Parse the server password.
std::string server_password = CommUtil::ParseConfig(*(config_), kServerPassword);
if (server_password.empty()) {
MS_LOG(EXCEPTION) << "The client password's value is empty.";
}
EVP_PKEY *pkey = nullptr;
X509 *cert = nullptr;
STACK_OF(X509) *ca_stack = nullptr;
BIO *bio = BIO_new_file(server_cert.c_str(), "rb");
MS_EXCEPTION_IF_NULL(bio);
PKCS12 *p12 = d2i_PKCS12_bio(bio, nullptr);
MS_EXCEPTION_IF_NULL(p12);
BIO_free_all(bio);
if (!PKCS12_parse(p12, server_password.c_str(), &pkey, &cert, &ca_stack)) {
MS_LOG(EXCEPTION) << "PKCS12_parse failed.";
}
PKCS12_free(p12);
std::string default_cipher_list = CommUtil::ParseConfig(*config_, kCipherList);
if (!SSL_CTX_set_cipher_list(ssl_ctx_, default_cipher_list.c_str())) {
MS_LOG(EXCEPTION) << "SSL use set cipher list failed!";
}
if (!SSL_CTX_use_certificate(ssl_ctx_, cert)) {
MS_LOG(EXCEPTION) << "SSL use certificate chain file failed!";
}
if (!SSL_CTX_use_PrivateKey(ssl_ctx_, pkey)) {
MS_LOG(EXCEPTION) << "SSL use private key file failed!";
}
if (!SSL_CTX_check_private_key(ssl_ctx_)) {
MS_LOG(EXCEPTION) << "SSL check private key file failed!";
}
}
void SSLHTTP::CleanSSL() {
if (ssl_ctx_ != nullptr) {
SSL_CTX_free(ssl_ctx_);
}
ERR_free_strings();
EVP_cleanup();
ERR_remove_thread_state(nullptr);
CRYPTO_cleanup_all_ex_data();
}
SSL_CTX *SSLHTTP::GetSSLCtx() const { return ssl_ctx_; }
} // namespace core
} // namespace ps
} // namespace mindspore

View File

@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_SSL_HTTP_H_
#define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_SSL_HTTP_H_
#include <openssl/ssl.h>
#include <openssl/rand.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <assert.h>
#include <openssl/pkcs12.h>
#include <openssl/bio.h>
#include <iostream>
#include <string>
#include <memory>
#include "utils/log_adapter.h"
#include "ps/core/comm_util.h"
#include "ps/constants.h"
#include "ps/core/file_configuration.h"
namespace mindspore {
namespace ps {
namespace core {
class SSLHTTP {
public:
static SSLHTTP &GetInstance() {
static SSLHTTP instance;
return instance;
}
SSL_CTX *GetSSLCtx() const;
private:
SSLHTTP();
virtual ~SSLHTTP();
SSLHTTP(const SSLHTTP &) = delete;
SSLHTTP &operator=(const SSLHTTP &) = delete;
void InitSSL();
void CleanSSL();
SSL_CTX *ssl_ctx_;
};
} // namespace core
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_SSL_HTTP_H_

View File

@ -44,10 +44,7 @@ SSLWrapper::SSLWrapper()
SSLWrapper::~SSLWrapper() { CleanSSL(); }
void SSLWrapper::InitSSL() {
SSL_library_init();
ERR_load_crypto_strings();
SSL_load_error_strings();
OpenSSL_add_all_algorithms();
CommUtil::InitOpenSSLEnv();
int rand = RAND_poll();
if (rand == 0) {
MS_LOG(ERROR) << "RAND_poll failed";

View File

@ -29,6 +29,7 @@
#include <string>
#include "utils/log_adapter.h"
#include "ps/core/comm_util.h"
namespace mindspore {
namespace ps {