switched SSL implementation to use boost ssl
This commit is contained in:
parent
1ed3ba7170
commit
84853dd1fd
|
@ -66,14 +66,9 @@ using std::max;
|
|||
using std::min;
|
||||
using std::pair;
|
||||
|
||||
NetworkOptions networkOptions;
|
||||
Reference<TLSOptions> tlsOptions;
|
||||
#define CERT_FILE_MAX_SIZE (5 * 1024 * 1024)
|
||||
|
||||
static void initTLSOptions() {
|
||||
if (!tlsOptions) {
|
||||
tlsOptions = Reference<TLSOptions>(new TLSOptions());
|
||||
}
|
||||
}
|
||||
NetworkOptions networkOptions;
|
||||
|
||||
static const Key CLIENT_LATENCY_INFO_PREFIX = LiteralStringRef("client_latency/");
|
||||
static const Key CLIENT_LATENCY_INFO_CTR_PREFIX = LiteralStringRef("client_latency_counter/");
|
||||
|
@ -887,43 +882,44 @@ void setNetworkOption(FDBNetworkOptions::Option option, Optional<StringRef> valu
|
|||
break;
|
||||
case FDBNetworkOptions::TLS_CERT_PATH:
|
||||
validateOptionValue(value, true);
|
||||
initTLSOptions();
|
||||
tlsOptions->set_cert_file( value.get().toString() );
|
||||
networkOptions.sslContext.use_certificate_chain_file(value.get().toString());
|
||||
break;
|
||||
case FDBNetworkOptions::TLS_CERT_BYTES:
|
||||
initTLSOptions();
|
||||
tlsOptions->set_cert_data( value.get().toString() );
|
||||
break;
|
||||
case FDBNetworkOptions::TLS_CA_PATH:
|
||||
case FDBNetworkOptions::TLS_CERT_BYTES: {
|
||||
validateOptionValue(value, true);
|
||||
initTLSOptions();
|
||||
tlsOptions->set_ca_file( value.get().toString() );
|
||||
std::string cert = value.get().toString();
|
||||
networkOptions.sslContext.use_certificate(boost::asio::buffer(cert.data(), cert.size()), boost::asio::ssl::context::pem);
|
||||
break;
|
||||
case FDBNetworkOptions::TLS_CA_BYTES:
|
||||
}
|
||||
case FDBNetworkOptions::TLS_CA_PATH: {
|
||||
validateOptionValue(value, true);
|
||||
initTLSOptions();
|
||||
tlsOptions->set_ca_data(value.get().toString());
|
||||
std::string cert = readFileBytes(value.get().toString(), CERT_FILE_MAX_SIZE);
|
||||
networkOptions.sslContext.add_certificate_authority(boost::asio::buffer(cert.data(), cert.size()));
|
||||
break;
|
||||
}
|
||||
case FDBNetworkOptions::TLS_CA_BYTES: {
|
||||
validateOptionValue(value, true);
|
||||
std::string cert = value.get().toString();
|
||||
networkOptions.sslContext.add_certificate_authority(boost::asio::buffer(cert.data(), cert.size()));
|
||||
break;
|
||||
}
|
||||
case FDBNetworkOptions::TLS_PASSWORD:
|
||||
validateOptionValue(value, true);
|
||||
initTLSOptions();
|
||||
tlsOptions->set_key_password(value.get().toString());
|
||||
networkOptions.tlsPassword = value.get().toString();
|
||||
break;
|
||||
case FDBNetworkOptions::TLS_KEY_PATH:
|
||||
validateOptionValue(value, true);
|
||||
initTLSOptions();
|
||||
tlsOptions->set_key_file( value.get().toString() );
|
||||
validateOptionValue(value, true);
|
||||
networkOptions.sslContext.use_private_key_file(value.get().toString(), boost::asio::ssl::context::pem);
|
||||
break;
|
||||
case FDBNetworkOptions::TLS_KEY_BYTES:
|
||||
case FDBNetworkOptions::TLS_KEY_BYTES: {
|
||||
validateOptionValue(value, true);
|
||||
initTLSOptions();
|
||||
tlsOptions->set_key_data( value.get().toString() );
|
||||
std::string cert = value.get().toString();
|
||||
networkOptions.sslContext.use_private_key(boost::asio::buffer(cert.data(), cert.size()), boost::asio::ssl::context::pem);
|
||||
break;
|
||||
}
|
||||
case FDBNetworkOptions::TLS_VERIFY_PEERS:
|
||||
validateOptionValue(value, true);
|
||||
initTLSOptions();
|
||||
try {
|
||||
tlsOptions->set_verify_peers({ value.get().toString() });
|
||||
//tlsOptions->set_verify_peers({ value.get().toString() }); FIXME
|
||||
} catch( Error& e ) {
|
||||
TraceEvent(SevWarnAlways, "TLSValidationSetError")
|
||||
.error( e )
|
||||
|
@ -987,15 +983,9 @@ void setupNetwork(uint64_t transportId, bool useMetrics) {
|
|||
if (!networkOptions.logClientInfo.present())
|
||||
networkOptions.logClientInfo = true;
|
||||
|
||||
g_network = newNet2(false, useMetrics || networkOptions.traceDirectory.present());
|
||||
g_network = newNet2(false, useMetrics || networkOptions.traceDirectory.present(), &networkOptions.sslContext, networkOptions.tlsPassword);
|
||||
FlowTransport::createInstance(true, transportId);
|
||||
Net2FileSystem::newFileSystem();
|
||||
|
||||
initTLSOptions();
|
||||
|
||||
#ifndef TLS_DISABLED
|
||||
tlsOptions->register_network();
|
||||
#endif
|
||||
}
|
||||
|
||||
void runNetwork() {
|
||||
|
|
|
@ -61,12 +61,14 @@ struct NetworkOptions {
|
|||
Optional<bool> logClientInfo;
|
||||
Standalone<VectorRef<ClientVersionRef>> supportedVersions;
|
||||
bool slowTaskProfilingEnabled;
|
||||
boost::asio::ssl::context sslContext;
|
||||
std::string tlsPassword;
|
||||
|
||||
// The default values, TRACE_DEFAULT_ROLL_SIZE and TRACE_DEFAULT_MAX_LOGS_SIZE are located in Trace.h.
|
||||
NetworkOptions()
|
||||
: localAddress(""), clusterFile(""), traceDirectory(Optional<std::string>()),
|
||||
traceRollSize(TRACE_DEFAULT_ROLL_SIZE), traceMaxLogsSize(TRACE_DEFAULT_MAX_LOGS_SIZE), traceLogGroup("default"),
|
||||
traceFormat("xml"), slowTaskProfilingEnabled(false) {}
|
||||
traceFormat("xml"), slowTaskProfilingEnabled(false), sslContext(boost::asio::ssl::context(boost::asio::ssl::context::tlsv12)), tlsPassword("") {}
|
||||
};
|
||||
|
||||
class Database {
|
||||
|
|
|
@ -988,10 +988,12 @@ ACTOR static Future<Void> listen( TransportData* self, NetworkAddress listenAddr
|
|||
try {
|
||||
loop {
|
||||
Reference<IConnection> conn = wait( listener->accept() );
|
||||
TraceEvent("ConnectionFrom", conn->getDebugID()).suppressFor(1.0)
|
||||
.detail("FromAddress", conn->getPeerAddress())
|
||||
.detail("ListenAddress", listenAddr.toString());
|
||||
incoming.add( connectionIncoming(self, conn) );
|
||||
if(conn) {
|
||||
TraceEvent("ConnectionFrom", conn->getDebugID()).suppressFor(1.0)
|
||||
.detail("FromAddress", conn->getPeerAddress())
|
||||
.detail("ListenAddress", listenAddr.toString());
|
||||
incoming.add( connectionIncoming(self, conn) );
|
||||
}
|
||||
wait(delay(0) || delay(FLOW_KNOBS->CONNECTION_ACCEPT_DELAY, TaskPriority::WriteSocket));
|
||||
}
|
||||
} catch (Error& e) {
|
||||
|
|
|
@ -356,12 +356,6 @@ void TLSOptions::set_verify_peers( std::vector<std::string> const& verify_peers
|
|||
verify_peers_set = true;
|
||||
}
|
||||
|
||||
void TLSOptions::register_network() {
|
||||
// Simulation relies upon being able to call this multiple times, and have it override g_network
|
||||
// each time it's called.
|
||||
new TLSNetworkConnections( Reference<TLSOptions>::addRef( this ) );
|
||||
}
|
||||
|
||||
ACTOR static Future<ErrorOr<Standalone<StringRef>>> readEntireFile( std::string filename ) {
|
||||
state Reference<IAsyncFile> file = wait(IAsyncFileSystem::filesystem()->open(filename, IAsyncFile::OPEN_READONLY | IAsyncFile::OPEN_UNCACHED, 0));
|
||||
state int64_t filesize = wait(file->size());
|
||||
|
|
|
@ -85,8 +85,6 @@ struct TLSOptions : ReferenceCounted<TLSOptions> {
|
|||
void set_key_data( std::string const& key_data );
|
||||
void set_verify_peers( std::vector<std::string> const& verify_peers );
|
||||
|
||||
void register_network();
|
||||
|
||||
Reference<ITLSPolicy> get_policy(PolicyType type);
|
||||
bool enabled();
|
||||
|
||||
|
|
|
@ -117,7 +117,6 @@ static void simInitTLS(Reference<TLSOptions> tlsOptions) {
|
|||
tlsOptions->set_cert_data( certBytes );
|
||||
tlsOptions->set_key_data( certBytes );
|
||||
tlsOptions->set_verify_peers(std::vector<std::string>(1, "Check.Valid=0"));
|
||||
tlsOptions->register_network();
|
||||
}
|
||||
|
||||
ACTOR Future<Void> runBackup( Reference<ClusterConnectionFile> connFile ) {
|
||||
|
@ -247,9 +246,6 @@ ACTOR Future<ISimulator::KillType> simulatedFDBDRebooter(Reference<ClusterConnec
|
|||
//SOMEDAY: test lower memory limits, without making them too small and causing the database to stop making progress
|
||||
FlowTransport::createInstance(processClass == ProcessClass::TesterClass || runBackupAgents == AgentOnly, 1);
|
||||
Sim2FileSystem::newFileSystem();
|
||||
if (sslEnabled) {
|
||||
tlsOptions->register_network();
|
||||
}
|
||||
|
||||
vector<Future<Void>> futures;
|
||||
for (int listenPort = port; listenPort < port + listenPerProcess; ++listenPort) {
|
||||
|
|
|
@ -81,6 +81,8 @@
|
|||
#include "flow/SimpleOpt.h"
|
||||
#include "flow/actorcompiler.h" // This must be the last #include.
|
||||
|
||||
#define CERT_FILE_MAX_SIZE (5 * 1024 * 1024)
|
||||
|
||||
enum {
|
||||
OPT_CONNFILE, OPT_SEEDCONNFILE, OPT_SEEDCONNSTRING, OPT_ROLE, OPT_LISTEN, OPT_PUBLICADDR, OPT_DATAFOLDER, OPT_LOGFOLDER, OPT_PARENTPID, OPT_NEWCONSOLE,
|
||||
OPT_NOBOX, OPT_TESTFILE, OPT_RESTARTING, OPT_RESTORING, OPT_RANDOMSEED, OPT_KEY, OPT_MEMLIMIT, OPT_STORAGEMEMLIMIT, OPT_CACHEMEMLIMIT, OPT_MACHINEID,
|
||||
|
@ -961,7 +963,7 @@ int main(int argc, char* argv[]) {
|
|||
int minTesterCount = 1;
|
||||
bool testOnServers = false;
|
||||
|
||||
Reference<TLSOptions> tlsOptions = Reference<TLSOptions>( new TLSOptions );
|
||||
boost::asio::ssl::context sslContext(boost::asio::ssl::context::tlsv12);
|
||||
std::string tlsCertPath, tlsKeyPath, tlsCAPath, tlsPassword;
|
||||
std::vector<std::string> tlsVerifyPeers;
|
||||
double fileIoTimeout = 0.0;
|
||||
|
@ -1551,7 +1553,21 @@ int main(int argc, char* argv[]) {
|
|||
startNewSimulator();
|
||||
openTraceFile(NetworkAddress(), rollsize, maxLogsSize, logFolder, "trace", logGroup);
|
||||
} else {
|
||||
g_network = newNet2(useThreadPool, true);
|
||||
#ifndef TLS_DISABLED
|
||||
if ( tlsCertPath.size() ) {
|
||||
sslContext.use_certificate_chain_file(tlsCertPath);
|
||||
}
|
||||
if (tlsCAPath.size()) {
|
||||
std::string cert = readFileBytes(tlsCAPath, CERT_FILE_MAX_SIZE);
|
||||
sslContext.add_certificate_authority(boost::asio::buffer(cert.data(), cert.size()));
|
||||
}
|
||||
if (tlsKeyPath.size()) {
|
||||
sslContext.use_private_key_file(tlsKeyPath, boost::asio::ssl::context::pem);
|
||||
}
|
||||
//if ( tlsVerifyPeers.size() ) FIXME
|
||||
// tlsOptions->set_verify_peers( tlsVerifyPeers );
|
||||
#endif
|
||||
g_network = newNet2(useThreadPool, true, &sslContext, tlsPassword);
|
||||
FlowTransport::createInstance(false, 1);
|
||||
|
||||
const bool expectsPublicAddress = (role == FDBD || role == NetworkTestServer || role == Restore);
|
||||
|
@ -1565,22 +1581,7 @@ int main(int argc, char* argv[]) {
|
|||
|
||||
openTraceFile(publicAddresses.address, rollsize, maxLogsSize, logFolder, "trace", logGroup);
|
||||
|
||||
#ifndef TLS_DISABLED
|
||||
if ( tlsCertPath.size() )
|
||||
tlsOptions->set_cert_file( tlsCertPath );
|
||||
if (tlsCAPath.size())
|
||||
tlsOptions->set_ca_file(tlsCAPath);
|
||||
if (tlsKeyPath.size()) {
|
||||
if (tlsPassword.size())
|
||||
tlsOptions->set_key_password(tlsPassword);
|
||||
|
||||
tlsOptions->set_key_file(tlsKeyPath);
|
||||
}
|
||||
if ( tlsVerifyPeers.size() )
|
||||
tlsOptions->set_verify_peers( tlsVerifyPeers );
|
||||
|
||||
tlsOptions->register_network();
|
||||
#endif
|
||||
if (expectsPublicAddress) {
|
||||
for (int ii = 0; ii < (publicAddresses.secondaryAddress.present() ? 2 : 1); ++ii) {
|
||||
const NetworkAddress& publicAddress = ii==0 ? publicAddresses.address : publicAddresses.secondaryAddress.get();
|
||||
|
@ -1789,7 +1790,7 @@ int main(int argc, char* argv[]) {
|
|||
}
|
||||
}
|
||||
}
|
||||
setupAndRun( dataFolder, testFile, restarting, (isRestoring >= 1), whitelistBinPaths, tlsOptions);
|
||||
setupAndRun( dataFolder, testFile, restarting, (isRestoring >= 1), whitelistBinPaths, Reference<TLSOptions>()); //FIXME
|
||||
g_simulator.run();
|
||||
} else if (role == FDBD) {
|
||||
ASSERT( connectionFile );
|
||||
|
|
|
@ -49,6 +49,7 @@ intptr_t g_stackYieldLimit = 0;
|
|||
|
||||
using namespace boost::asio::ip;
|
||||
|
||||
typedef boost::asio::ssl::stream<boost::asio::ip::tcp::socket&> ssl_socket;
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <execinfo.h>
|
||||
|
@ -111,7 +112,7 @@ thread_local INetwork* thread_network = 0;
|
|||
class Net2 sealed : public INetwork, public INetworkConnections {
|
||||
|
||||
public:
|
||||
Net2(bool useThreadPool, bool useMetrics);
|
||||
Net2(bool useThreadPool, bool useMetrics, boost::asio::ssl::context* sslContext, std::string tlsPassword);
|
||||
void run();
|
||||
void initMetrics();
|
||||
|
||||
|
@ -154,6 +155,13 @@ public:
|
|||
//private:
|
||||
|
||||
ASIOReactor reactor;
|
||||
boost::asio::ssl::context* sslContext;
|
||||
std::string tlsPassword;
|
||||
|
||||
std::string get_password() const {
|
||||
return tlsPassword;
|
||||
}
|
||||
|
||||
INetworkConnections *network; // initially this, but can be changed
|
||||
|
||||
int64_t tsc_begin, tsc_end;
|
||||
|
@ -429,6 +437,216 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
class SSLConnection : public IConnection, ReferenceCounted<SSLConnection> {
|
||||
public:
|
||||
virtual void addref() { ReferenceCounted<SSLConnection>::addref(); }
|
||||
virtual void delref() { ReferenceCounted<SSLConnection>::delref(); }
|
||||
|
||||
virtual void close() {
|
||||
closeSocket();
|
||||
}
|
||||
|
||||
explicit SSLConnection( boost::asio::io_service& io_service, boost::asio::ssl::context& context )
|
||||
: id(nondeterministicRandom()->randomUniqueID()), socket(io_service), ssl_sock(socket, context)
|
||||
{
|
||||
}
|
||||
|
||||
// This is not part of the IConnection interface, because it is wrapped by INetwork::connect()
|
||||
ACTOR static Future<Reference<IConnection>> connect( boost::asio::io_service* ios, boost::asio::ssl::context* context, NetworkAddress addr ) {
|
||||
state std::pair<IPAddress,uint16_t> peerIP = std::make_pair(addr.ip, addr.port);
|
||||
auto iter(g_network->networkInfo.serverTLSConnectionThrottler.find(peerIP));
|
||||
if(iter != g_network->networkInfo.serverTLSConnectionThrottler.end()) {
|
||||
if (now() < iter->second.second) {
|
||||
if(iter->second.first >= FLOW_KNOBS->TLS_CLIENT_CONNECTION_THROTTLE_ATTEMPTS) {
|
||||
TraceEvent("TLSOutgoingConnectionThrottlingWarning").suppressFor(1.0).detail("PeerIP", addr);
|
||||
wait(delay(FLOW_KNOBS->CONNECTION_MONITOR_TIMEOUT));
|
||||
throw connection_failed();
|
||||
}
|
||||
} else {
|
||||
g_network->networkInfo.serverTLSConnectionThrottler.erase(peerIP);
|
||||
}
|
||||
}
|
||||
|
||||
state Reference<SSLConnection> self( new SSLConnection(*ios, *context) );
|
||||
|
||||
self->peer_address = addr;
|
||||
try {
|
||||
auto to = tcpEndpoint(addr);
|
||||
BindPromise p("N2_ConnectError", self->id);
|
||||
Future<Void> onConnected = p.getFuture();
|
||||
self->socket.async_connect( to, std::move(p) );
|
||||
|
||||
wait( onConnected );
|
||||
try {
|
||||
BindPromise p("N2_ConnectHandshakeError", self->id);
|
||||
Future<Void> onHandshook = p.getFuture();
|
||||
self->ssl_sock.async_handshake( boost::asio::ssl::stream_base::client, std::move(p) );
|
||||
wait( onHandshook );
|
||||
} catch (Error& e) {
|
||||
auto iter(g_network->networkInfo.serverTLSConnectionThrottler.find(peerIP));
|
||||
if(iter != g_network->networkInfo.serverTLSConnectionThrottler.end()) {
|
||||
iter->second.first++;
|
||||
} else {
|
||||
g_network->networkInfo.serverTLSConnectionThrottler[peerIP] = std::make_pair(0,now() + FLOW_KNOBS->TLS_CLIENT_CONNECTION_THROTTLE_TIMEOUT);
|
||||
}
|
||||
throw;
|
||||
}
|
||||
|
||||
self->init();
|
||||
return self;
|
||||
} catch (Error& e) {
|
||||
// Either the connection failed, or was cancelled by the caller
|
||||
self->closeSocket();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
// This is not part of the IConnection interface, because it is wrapped by IListener::accept()
|
||||
void accept(NetworkAddress peerAddr) {
|
||||
this->peer_address = peerAddr;
|
||||
init();
|
||||
}
|
||||
|
||||
// returns when write() can write at least one byte
|
||||
virtual Future<Void> onWritable() {
|
||||
++g_net2->countWriteProbes;
|
||||
BindPromise p("N2_WriteProbeError", id);
|
||||
auto f = p.getFuture();
|
||||
socket.async_write_some( boost::asio::null_buffers(), std::move(p) );
|
||||
return f;
|
||||
}
|
||||
|
||||
// returns when read() can read at least one byte
|
||||
virtual Future<Void> onReadable() {
|
||||
++g_net2->countReadProbes;
|
||||
BindPromise p("N2_ReadProbeError", id);
|
||||
auto f = p.getFuture();
|
||||
socket.async_read_some( boost::asio::null_buffers(), std::move(p) );
|
||||
return f;
|
||||
}
|
||||
|
||||
// Reads as many bytes as possible from the read buffer into [begin,end) and returns the number of bytes read (might be 0)
|
||||
virtual int read( uint8_t* begin, uint8_t* end ) {
|
||||
boost::system::error_code err;
|
||||
++g_net2->countReads;
|
||||
size_t toRead = end-begin;
|
||||
size_t size = ssl_sock.read_some( boost::asio::mutable_buffers_1(begin, toRead), err );
|
||||
g_net2->bytesReceived += size;
|
||||
//TraceEvent("ConnRead", this->id).detail("Bytes", size);
|
||||
if (err) {
|
||||
if (err == boost::asio::error::would_block) {
|
||||
++g_net2->countWouldBlock;
|
||||
return 0;
|
||||
}
|
||||
onReadError(err);
|
||||
throw connection_failed();
|
||||
}
|
||||
ASSERT( size ); // If the socket is closed, we expect an 'eof' error, not a zero return value
|
||||
|
||||
return size;
|
||||
}
|
||||
|
||||
// Writes as many bytes as possible from the given SendBuffer chain into the write buffer and returns the number of bytes written (might be 0)
|
||||
virtual int write( SendBuffer const* data, int limit ) {
|
||||
boost::system::error_code err;
|
||||
++g_net2->countWrites;
|
||||
|
||||
size_t sent = ssl_sock.write_some( boost::iterator_range<SendBufferIterator>(SendBufferIterator(data, limit), SendBufferIterator()), err );
|
||||
|
||||
if (err) {
|
||||
// Since there was an error, sent's value can't be used to infer that the buffer has data and the limit is positive so check explicitly.
|
||||
ASSERT(limit > 0);
|
||||
bool notEmpty = false;
|
||||
for(auto p = data; p; p = p->next)
|
||||
if(p->bytes_written - p->bytes_sent > 0) {
|
||||
notEmpty = true;
|
||||
break;
|
||||
}
|
||||
ASSERT(notEmpty);
|
||||
|
||||
if (err == boost::asio::error::would_block) {
|
||||
++g_net2->countWouldBlock;
|
||||
return 0;
|
||||
}
|
||||
onWriteError(err);
|
||||
throw connection_failed();
|
||||
}
|
||||
|
||||
ASSERT( sent ); // Make sure data was sent, and also this check will fail if the buffer chain was empty or the limit was not > 0.
|
||||
return sent;
|
||||
}
|
||||
|
||||
virtual NetworkAddress getPeerAddress() { return peer_address; }
|
||||
|
||||
virtual UID getDebugID() { return id; }
|
||||
|
||||
tcp::socket& getSocket() { return socket; }
|
||||
|
||||
ssl_socket& getSSLSocket() { return ssl_sock; }
|
||||
private:
|
||||
UID id;
|
||||
tcp::socket socket;
|
||||
ssl_socket ssl_sock;
|
||||
NetworkAddress peer_address;
|
||||
|
||||
struct SendBufferIterator {
|
||||
typedef boost::asio::const_buffer value_type;
|
||||
typedef std::forward_iterator_tag iterator_category;
|
||||
typedef size_t difference_type;
|
||||
typedef boost::asio::const_buffer* pointer;
|
||||
typedef boost::asio::const_buffer& reference;
|
||||
|
||||
SendBuffer const* p;
|
||||
int limit;
|
||||
|
||||
SendBufferIterator(SendBuffer const* p=0, int limit = std::numeric_limits<int>::max()) : p(p), limit(limit) {
|
||||
ASSERT(limit > 0);
|
||||
}
|
||||
|
||||
bool operator == (SendBufferIterator const& r) const { return p == r.p; }
|
||||
bool operator != (SendBufferIterator const& r) const { return p != r.p; }
|
||||
void operator++() {
|
||||
limit -= p->bytes_written - p->bytes_sent;
|
||||
if(limit > 0)
|
||||
p = p->next;
|
||||
else
|
||||
p = NULL;
|
||||
}
|
||||
|
||||
boost::asio::const_buffer operator*() const {
|
||||
return boost::asio::const_buffer( p->data + p->bytes_sent, std::min(limit, p->bytes_written - p->bytes_sent) );
|
||||
}
|
||||
};
|
||||
|
||||
void init() {
|
||||
// Socket settings that have to be set after connect or accept succeeds
|
||||
socket.non_blocking(true);
|
||||
socket.set_option(boost::asio::ip::tcp::no_delay(true));
|
||||
platform::setCloseOnExec(socket.native_handle());
|
||||
}
|
||||
|
||||
void closeSocket() {
|
||||
try {
|
||||
socket.cancel();
|
||||
} catch(...) {}
|
||||
try {
|
||||
socket.close();
|
||||
} catch(...) {}
|
||||
try {
|
||||
ssl_sock.shutdown();
|
||||
} catch(...) {}
|
||||
}
|
||||
|
||||
void onReadError( const boost::system::error_code& error ) {
|
||||
TraceEvent(SevWarn, "N2_ReadError", id).suppressFor(1.0).detail("Message", error.value());
|
||||
closeSocket();
|
||||
}
|
||||
void onWriteError( const boost::system::error_code& error ) {
|
||||
TraceEvent(SevWarn, "N2_WriteError", id).suppressFor(1.0).detail("Message", error.value());
|
||||
closeSocket();
|
||||
}
|
||||
};
|
||||
|
||||
class Listener : public IListener, ReferenceCounted<Listener> {
|
||||
NetworkAddress listenAddress;
|
||||
tcp::acceptor acceptor;
|
||||
|
@ -471,6 +689,77 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
class SSLListener : public IListener, ReferenceCounted<SSLListener> {
|
||||
NetworkAddress listenAddress;
|
||||
tcp::acceptor acceptor;
|
||||
boost::asio::ssl::context* context;
|
||||
|
||||
public:
|
||||
SSLListener( boost::asio::io_service& io_service, boost::asio::ssl::context* context, NetworkAddress listenAddress )
|
||||
: listenAddress(listenAddress), acceptor( io_service, tcpEndpoint( listenAddress ) ), context(context)
|
||||
{
|
||||
platform::setCloseOnExec(acceptor.native_handle());
|
||||
}
|
||||
|
||||
virtual void addref() { ReferenceCounted<SSLListener>::addref(); }
|
||||
virtual void delref() { ReferenceCounted<SSLListener>::delref(); }
|
||||
|
||||
// Returns one incoming connection when it is available
|
||||
virtual Future<Reference<IConnection>> accept() {
|
||||
return doAccept( this );
|
||||
}
|
||||
|
||||
virtual NetworkAddress getListenAddress() { return listenAddress; }
|
||||
|
||||
private:
|
||||
ACTOR static Future<Reference<IConnection>> doAccept( SSLListener* self ) {
|
||||
state Reference<SSLConnection> conn( new SSLConnection( self->acceptor.get_io_service(), *self->context) );
|
||||
state tcp::acceptor::endpoint_type peer_endpoint;
|
||||
try {
|
||||
BindPromise p("N2_AcceptError", UID());
|
||||
auto f = p.getFuture();
|
||||
self->acceptor.async_accept( conn->getSocket(), peer_endpoint, std::move(p) );
|
||||
wait( f );
|
||||
state IPAddress peer_address = peer_endpoint.address().is_v6() ? IPAddress(peer_endpoint.address().to_v6().to_bytes()) : IPAddress(peer_endpoint.address().to_v4().to_ulong());
|
||||
state std::pair<IPAddress,uint16_t> peerIP = std::make_pair(peer_address, static_cast<uint16_t>(0));
|
||||
auto iter(g_network->networkInfo.serverTLSConnectionThrottler.find(peerIP));
|
||||
if(iter != g_network->networkInfo.serverTLSConnectionThrottler.end()) {
|
||||
if (now() < iter->second.second) {
|
||||
if(iter->second.first >= FLOW_KNOBS->TLS_SERVER_CONNECTION_THROTTLE_ATTEMPTS) {
|
||||
TraceEvent("TLSIncomingConnectionThrottlingWarning").suppressFor(1.0).detail("PeerIP", peerIP.first.toString());
|
||||
wait(delay(FLOW_KNOBS->CONNECTION_MONITOR_TIMEOUT));
|
||||
throw connection_failed();
|
||||
}
|
||||
} else {
|
||||
g_network->networkInfo.serverTLSConnectionThrottler.erase(peerIP);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
BindPromise p("N2_AcceptHandshakeError", UID());
|
||||
auto f = p.getFuture();
|
||||
conn->getSSLSocket().async_handshake( boost::asio::ssl::stream_base::server, std::move(p) );
|
||||
wait( f );
|
||||
} catch (...) {
|
||||
auto iter(g_network->networkInfo.serverTLSConnectionThrottler.find(peerIP));
|
||||
if(iter != g_network->networkInfo.serverTLSConnectionThrottler.end()) {
|
||||
iter->second.first++;
|
||||
} else {
|
||||
g_network->networkInfo.serverTLSConnectionThrottler[peerIP] = std::make_pair(0,now() + FLOW_KNOBS->TLS_SERVER_CONNECTION_THROTTLE_TIMEOUT);
|
||||
}
|
||||
throw;
|
||||
}
|
||||
|
||||
conn->accept(NetworkAddress(peer_address, peer_endpoint.port(), false, true));
|
||||
|
||||
return conn;
|
||||
} catch (...) {
|
||||
conn->close();
|
||||
return Reference<IConnection>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct PromiseTask : public Task, public FastAllocated<PromiseTask> {
|
||||
Promise<Void> promise;
|
||||
PromiseTask() {}
|
||||
|
@ -482,7 +771,7 @@ struct PromiseTask : public Task, public FastAllocated<PromiseTask> {
|
|||
}
|
||||
};
|
||||
|
||||
Net2::Net2(bool useThreadPool, bool useMetrics)
|
||||
Net2::Net2(bool useThreadPool, bool useMetrics, boost::asio::ssl::context* sslContext, std::string tlsPassword)
|
||||
: useThreadPool(useThreadPool),
|
||||
network(this),
|
||||
reactor(this),
|
||||
|
@ -491,10 +780,16 @@ Net2::Net2(bool useThreadPool, bool useMetrics)
|
|||
// Until run() is called, yield() will always yield
|
||||
tsc_begin(0), tsc_end(0), taskBegin(0), currentTaskID(TaskPriority::DefaultYield),
|
||||
lastMinTaskID(TaskPriority::Zero),
|
||||
numYields(0)
|
||||
numYields(0),
|
||||
sslContext(sslContext),
|
||||
tlsPassword(tlsPassword)
|
||||
{
|
||||
TraceEvent("Net2Starting");
|
||||
|
||||
if(sslContext) {
|
||||
sslContext->set_password_callback(std::bind(&Net2::get_password, this));
|
||||
}
|
||||
|
||||
// Set the global members
|
||||
if(useMetrics) {
|
||||
setGlobal(INetwork::enTDMetrics, (flowGlobalType) &tdmetrics);
|
||||
|
@ -870,8 +1165,11 @@ THREAD_HANDLE Net2::startThread( THREAD_FUNC_RETURN (*func) (void*), void *arg )
|
|||
return ::startThread(func, arg);
|
||||
}
|
||||
|
||||
|
||||
Future< Reference<IConnection> > Net2::connect( NetworkAddress toAddr, std::string host ) {
|
||||
if ( toAddr.isTLS() ) {
|
||||
return SSLConnection::connect(&this->reactor.ios, this->sslContext, toAddr);
|
||||
}
|
||||
|
||||
return Connection::connect(&this->reactor.ios, toAddr);
|
||||
}
|
||||
|
||||
|
@ -945,6 +1243,9 @@ bool Net2::isAddressOnThisHost( NetworkAddress const& addr ) {
|
|||
|
||||
Reference<IListener> Net2::listen( NetworkAddress localAddr ) {
|
||||
try {
|
||||
if ( localAddr.isTLS() ) {
|
||||
return Reference<IListener>(new SSLListener( reactor.ios, this->sslContext, localAddr ));
|
||||
}
|
||||
return Reference<IListener>( new Listener( reactor.ios, localAddr ) );
|
||||
} catch (boost::system::system_error const& e) {
|
||||
Error x;
|
||||
|
@ -1039,9 +1340,55 @@ void ASIOReactor::wake() {
|
|||
|
||||
} // namespace net2
|
||||
|
||||
INetwork* newNet2(bool useThreadPool, bool useMetrics) {
|
||||
bool verify_certificate_cb(bool preverified, boost::asio::ssl::verify_context& ctx)
|
||||
{
|
||||
/*
|
||||
std::cout << "Function : " << __func__ << " ----------------- Line : " << __LINE__ << std::endl;
|
||||
int8_t subject_name[256];
|
||||
X509_STORE_CTX *cts = ctx.native_handle();
|
||||
int32_t length = 0;
|
||||
X509* cert = X509_STORE_CTX_get_current_cert(ctx.native_handle());
|
||||
std::cout << "CTX ERROR : " << cts->error << std::endl;
|
||||
|
||||
int32_t depth = X509_STORE_CTX_get_error_depth(cts);
|
||||
std::cout << "CTX DEPTH : " << depth << std::endl;
|
||||
|
||||
switch (cts->error)
|
||||
{
|
||||
case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT:
|
||||
printf("X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT\n");
|
||||
break;
|
||||
case X509_V_ERR_CERT_NOT_YET_VALID:
|
||||
case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD:
|
||||
printf("Certificate not yet valid!!\n");
|
||||
break;
|
||||
case X509_V_ERR_CERT_HAS_EXPIRED:
|
||||
case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD:
|
||||
printf("Certificate expired..\n");
|
||||
break;
|
||||
case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
|
||||
printf("Self signed certificate in chain!!!\n");
|
||||
preverified = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
const int32_t name_length = 256;
|
||||
X509_NAME_oneline(X509_get_subject_name(cert), reinterpret_cast<char*>(subject_name), name_length);
|
||||
printf("Verifying %s\n", subject_name);
|
||||
printf("Verification status : %d\n", preverified);
|
||||
|
||||
std::cout << "Function : " << __func__ << " ----------------- Line : " << __LINE__ << std::endl;
|
||||
*/
|
||||
return true;
|
||||
}
|
||||
|
||||
INetwork* newNet2(bool useThreadPool, bool useMetrics, boost::asio::ssl::context* sslContext, std::string tlsPassword) {
|
||||
try {
|
||||
N2::g_net2 = new N2::Net2(useThreadPool, useMetrics);
|
||||
sslContext->set_options(boost::asio::ssl::context::default_workarounds);
|
||||
sslContext->set_verify_mode(boost::asio::ssl::context::verify_peer | boost::asio::ssl::verify_fail_if_no_peer_cert);
|
||||
sslContext->set_verify_callback(boost::bind(&verify_certificate_cb, _1, _2));
|
||||
N2::g_net2 = new N2::Net2(useThreadPool, useMetrics, sslContext, tlsPassword);
|
||||
}
|
||||
catch(boost::system::system_error e) {
|
||||
TraceEvent("Net2InitError").detail("Message", e.what());
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include <stdint.h>
|
||||
#include <variant>
|
||||
#include "boost/asio.hpp"
|
||||
#include "boost/asio/ssl.hpp"
|
||||
#include "flow/serialize.h"
|
||||
#include "flow/IRandom.h"
|
||||
|
||||
|
@ -390,7 +391,7 @@ typedef NetworkAddressList (*NetworkAddressesFuncPtr)();
|
|||
|
||||
class INetwork;
|
||||
extern INetwork* g_network;
|
||||
extern INetwork* newNet2(bool useThreadPool = false, bool useMetrics = false);
|
||||
extern INetwork* newNet2(bool useThreadPool = false, bool useMetrics = false, boost::asio::ssl::context* sslContext = nullptr, std::string tlsPassword = "");
|
||||
|
||||
class INetwork {
|
||||
public:
|
||||
|
|
Loading…
Reference in New Issue