foundationdb/fdbrpc/TLSConnection.actor.cpp

328 lines
9.8 KiB
C++

/*
* TLSConnection.actor.cpp
*
* This source file is part of the FoundationDB open source project
*
* Copyright 2013-2018 Apple Inc. and the FoundationDB project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "flow/actorcompiler.h"
#include "flow/network.h"
#include "TLSConnection.h"
#include "ITLSPlugin.h"
#include "LoadPlugin.h"
#include "Platform.h"
// Must not throw an exception from this function!
static int send_func(void* ctx, const uint8_t* buf, int len) {
TLSConnection* conn = (TLSConnection*)ctx;
try {
SendBuffer sb;
sb.bytes_sent = 0;
sb.bytes_written = len;
sb.data = buf;
sb.next = 0;
int w = conn->conn->write( &sb );
return w;
} catch ( Error& e ) {
TraceEvent("TLSConnectionSendError", conn->getDebugID()).error(e);
return -1;
} catch ( ... ) {
TraceEvent("TLSConnectionSendError", conn->getDebugID()).error( unknown_error() );
return -1;
}
}
// Must not throw an exception from this function!
static int recv_func(void* ctx, uint8_t* buf, int len) {
TLSConnection* conn = (TLSConnection*)ctx;
try {
int r = conn->conn->read( buf, buf + len );
return r;
} catch ( Error& e ) {
TraceEvent("TLSConnectionRecvError", conn->getDebugID()).error(e);
return -1;
} catch ( ... ) {
TraceEvent("TLSConnectionRecvError", conn->getDebugID()).error( unknown_error() );
return -1;
}
}
ACTOR static Future<Void> handshake( TLSConnection* self ) {
loop {
int r = self->session->handshake();
if ( r == ITLSSession::SUCCESS ) break;
if ( r == ITLSSession::FAILED ) {
TraceEvent("TLSConnectionHandshakeError", self->getDebugID());
throw connection_failed();
}
ASSERT( r == ITLSSession::WANT_WRITE || r == ITLSSession::WANT_READ );
Void _ = wait( r == ITLSSession::WANT_WRITE ? self->conn->onWritable() : self->conn->onReadable() );
}
TraceEvent("TLSConnectionHandshakeSuccessful", self->getDebugID())
.detail("Peer", self->getPeerAddress());
return Void();
}
TLSConnection::TLSConnection( Reference<IConnection> const& conn, Reference<ITLSPolicy> const& policy, bool is_client ) : conn(conn), write_wants(0), read_wants(0), uid(conn->getDebugID()) {
session = Reference<ITLSSession>( policy->create_session(is_client, send_func, this, recv_func, this, (void*)&uid) );
if ( !session ) {
// If session is NULL, we're trusting policy->create_session
// to have used its provided logging function to have logged
// the error
throw internal_error();
}
handshook = handshake(this);
}
Future<Void> TLSConnection::onWritable() {
if ( !handshook.isReady() )
return handshook;
return
write_wants == ITLSSession::WANT_READ ? conn->onReadable() :
write_wants == ITLSSession::WANT_WRITE ? conn->onWritable() :
Void();
}
Future<Void> TLSConnection::onReadable() {
if ( !handshook.isReady() )
return handshook;
return
read_wants == ITLSSession::WANT_READ ? conn->onReadable() :
read_wants == ITLSSession::WANT_WRITE ? conn->onWritable() :
Void();
}
int TLSConnection::read( uint8_t* begin, uint8_t* end ) {
if ( !handshook.isReady() ) return 0;
handshook.get();
read_wants = 0;
int r = session->read( begin, end - begin );
if ( r > 0 )
return r;
if ( r == ITLSSession::FAILED ) throw connection_failed();
ASSERT( r == ITLSSession::WANT_WRITE || r == ITLSSession::WANT_READ );
read_wants = r;
return 0;
}
int TLSConnection::write( SendBuffer const* buffer, int limit ) {
ASSERT(limit > 0);
if ( !handshook.isReady() ) return 0;
handshook.get();
write_wants = 0;
int toSend = std::min(limit, buffer->bytes_written - buffer->bytes_sent);
ASSERT(toSend);
int w = session->write( buffer->data + buffer->bytes_sent, toSend );
if ( w > 0 )
return w;
if ( w == ITLSSession::FAILED ) throw connection_failed();
ASSERT( w == ITLSSession::WANT_WRITE || w == ITLSSession::WANT_READ );
write_wants = w;
return 0;
}
ACTOR Future<Reference<IConnection>> wrap( Reference<ITLSPolicy> policy, bool is_client, Future<Reference<IConnection>> c ) {
Reference<IConnection> conn = wait(c);
return Reference<IConnection>(new TLSConnection( conn, policy, is_client ));
}
Future<Reference<IConnection>> TLSListener::accept() {
return wrap( policy, false, listener->accept() );
}
TLSNetworkConnections::TLSNetworkConnections( Reference<TLSOptions> options ) : options(options) {
network = INetworkConnections::net();
g_network->setGlobal(INetwork::enumGlobal::enNetworkConnections, (flowGlobalType) this);
}
Future<Reference<IConnection>> TLSNetworkConnections::connect( NetworkAddress toAddr ) {
if ( toAddr.isTLS() ) {
NetworkAddress clearAddr( toAddr.ip, toAddr.port, toAddr.isPublic(), false );
TraceEvent("TLSConnectionConnecting").detail("ToAddr", toAddr);
return wrap( options->get_policy(), true, network->connect( clearAddr ) );
}
return network->connect( toAddr );
}
Future<std::vector<NetworkAddress>> TLSNetworkConnections::resolveTCPEndpoint( std::string host, std::string service) {
return network->resolveTCPEndpoint( host, service );
}
Reference<IListener> TLSNetworkConnections::listen( NetworkAddress localAddr ) {
if ( localAddr.isTLS() ) {
NetworkAddress clearAddr( localAddr.ip, localAddr.port, localAddr.isPublic(), false );
TraceEvent("TLSConnectionListening").detail("OnAddr", localAddr);
return Reference<IListener>(new TLSListener( options->get_policy(), network->listen( clearAddr ) ));
}
return network->listen( localAddr );
}
// 5MB for loading files into memory
#define CERT_FILE_MAX_SIZE (5 * 1024 * 1024)
void TLSOptions::set_plugin_name_or_path( std::string const& plugin_name_or_path ) {
if ( plugin )
throw invalid_option();
init_plugin( plugin_name_or_path );
}
void TLSOptions::set_cert_file( std::string const& cert_file ) {
try {
TraceEvent("TLSConnectionSettingCertFile").detail("CertFilePath", cert_file);
set_cert_data( readFileBytes( cert_file, CERT_FILE_MAX_SIZE ) );
} catch ( Error& ) {
TraceEvent(SevError, "TLSOptionsSetCertFileError").detail("Filename", cert_file);
throw;
}
}
void TLSOptions::set_cert_data( std::string const& cert_data ) {
if ( !policy )
init_plugin();
TraceEvent("TLSConnectionSettingCertData").detail("CertDataSize", cert_data.size());
if ( !policy->set_cert_data( (const uint8_t*)&cert_data[0], cert_data.size() ) )
throw tls_error();
certs_set = true;
}
void TLSOptions::set_key_file( std::string const& key_file ) {
try {
TraceEvent("TLSConnectionSettingKeyFile").detail("KeyFilePath", key_file);
set_key_data( readFileBytes( key_file, CERT_FILE_MAX_SIZE ) );
} catch ( Error& ) {
TraceEvent(SevError, "TLSOptionsSetKeyFileError").detail("Filename", key_file);
throw;
}
}
void TLSOptions::set_key_data( std::string const& key_data ) {
if ( !policy )
init_plugin();
TraceEvent("TLSConnectionSettingKeyData").detail("KeyDataSize", key_data.size());
if ( !policy->set_key_data( (const uint8_t*)&key_data[0], key_data.size() ) )
throw tls_error();
key_set = true;
}
void TLSOptions::set_verify_peers( std::string const& verify_peers ) {
if ( !policy )
init_plugin();
TraceEvent("TLSConnectionSettingVerifyPeers").detail("Value", verify_peers);
if ( !policy->set_verify_peers( (const uint8_t*)&verify_peers[0], verify_peers.size() ) )
throw tls_error();
verify_peers_set = true;
}
void TLSOptions::register_network() {
new TLSNetworkConnections( Reference<TLSOptions>::addRef( this ) );
}
const char *defaultCertFileName = "fdb.pem";
Reference<ITLSPolicy> TLSOptions::get_policy() {
if ( !certs_set ) {
std::string certFile;
if ( !platform::getEnvironmentVar( "FDB_TLS_CERTIFICATE_FILE", certFile ) )
certFile = fileExists(defaultCertFileName) ? defaultCertFileName : joinPath(platform::getDefaultConfigPath(), defaultCertFileName);
set_cert_file( certFile );
}
if ( !key_set ) {
std::string keyFile;
if ( !platform::getEnvironmentVar( "FDB_TLS_KEY_FILE", keyFile ) )
keyFile = fileExists(defaultCertFileName) ? defaultCertFileName : joinPath(platform::getDefaultConfigPath(), defaultCertFileName);
set_key_file( keyFile );
}
if( !verify_peers_set ) {
std::string verifyPeerString;
if ( platform::getEnvironmentVar( "FDB_TLS_VERIFY_PEERS", verifyPeerString ) )
set_verify_peers( verifyPeerString );
}
return policy;
}
static void TLSConnectionLogFunc( const char* event, void* uid_ptr, int is_error, ... ) {
UID uid;
if ( uid_ptr )
uid = *(UID*)uid_ptr;
Severity s = SevInfo;
if ( is_error )
s = SevError;
auto t = TraceEvent( s, event, uid );
va_list ap;
char* field;
va_start( ap, is_error );
while ( (field = va_arg( ap, char* )) ) {
t.detail( field, va_arg( ap, char* ) );
}
va_end( ap );
}
void TLSOptions::init_plugin( std::string const& plugin_path ) {
std::string path;
if ( plugin_path.length() ) {
path = plugin_path;
} else {
if ( !platform::getEnvironmentVar( "FDB_TLS_PLUGIN", path ) )
// FIXME: should there be other fallbacks?
path = platform::getDefaultPluginPath("FDBGnuTLS");
}
TraceEvent("TLSConnectionLoadingPlugin").detail("PluginPath", path);
plugin = loadPlugin<ITLSPlugin>( path.c_str() );
if ( !plugin ) {
// FIXME: allow?
TraceEvent(SevError, "TLSConnectionPluginInitError").detail("Plugin", path).GetLastError();
throw tls_error();
}
policy = Reference<ITLSPolicy>( plugin->create_policy( TLSConnectionLogFunc ) );
if ( !policy ) {
// Hopefully create_policy logged something with the log func
TraceEvent(SevError, "TLSConnectionCreatePolicyError");
throw tls_error();
}
}