Accept non-mTLS clients as untrusted
This commit is contained in:
parent
83e94781b4
commit
1dbb13a899
|
@ -0,0 +1,347 @@
|
|||
/*
|
||||
* TLSTest.cpp
|
||||
*
|
||||
* This source file is part of the FoundationDB open source project
|
||||
*
|
||||
* Copyright 2013-2022 Apple Inc. and the FoundationDB project authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <fmt/format.h>
|
||||
#include <unistd.h>
|
||||
#include <string_view>
|
||||
#include <signal.h>
|
||||
#include <sys/wait.h>
|
||||
#include "flow/Arena.h"
|
||||
#include "flow/MkCert.h"
|
||||
#include "flow/ScopeExit.h"
|
||||
#include "flow/TLSConfig.actor.h"
|
||||
#include "fdbrpc/fdbrpc.h"
|
||||
#include "fdbrpc/FlowTransport.h"
|
||||
#include "flow/actorcompiler.h" // This must be the last #include.
|
||||
|
||||
std::FILE* outp = stdout;
|
||||
|
||||
template <class... Args>
|
||||
void log(Args&&... args) {
|
||||
auto buf = fmt::memory_buffer{};
|
||||
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
|
||||
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void logc(Args&&... args) {
|
||||
auto buf = fmt::memory_buffer{};
|
||||
fmt::format_to(std::back_inserter(buf), "[CLIENT] ");
|
||||
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
|
||||
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void logs(Args&&... args) {
|
||||
auto buf = fmt::memory_buffer{};
|
||||
fmt::format_to(std::back_inserter(buf), "[SERVER] ");
|
||||
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
|
||||
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
void logm(Args&&... args) {
|
||||
auto buf = fmt::memory_buffer{};
|
||||
fmt::format_to(std::back_inserter(buf), "[ MAIN ] ");
|
||||
fmt::format_to(std::back_inserter(buf), std::forward<Args>(args)...);
|
||||
fmt::print(outp, "{}\n", std::string_view(buf.data(), buf.size()));
|
||||
}
|
||||
|
||||
struct TLSCreds {
|
||||
std::string certBytes;
|
||||
std::string keyBytes;
|
||||
std::string caBytes;
|
||||
};
|
||||
|
||||
TLSCreds makeCreds(int chainLen, mkcert::ESide side) {
|
||||
if (chainLen == 0)
|
||||
return {};
|
||||
auto arena = Arena();
|
||||
auto ret = TLSCreds{};
|
||||
auto specs = mkcert::makeCertChainSpec(arena, std::labs(chainLen), side);
|
||||
if (chainLen < 0) {
|
||||
specs[0].offsetNotBefore = -60l * 60 * 24 * 365;
|
||||
specs[0].offsetNotAfter = -10l; // cert that expired 10 seconds ago
|
||||
}
|
||||
auto chain = mkcert::makeCertChain(arena, specs, {} /* create root CA cert from spec*/);
|
||||
if (chain.size() == 1) {
|
||||
ret.certBytes = concatCertChain(arena, chain).toString();
|
||||
} else {
|
||||
auto nonRootChain = chain;
|
||||
nonRootChain.pop_back();
|
||||
ret.certBytes = concatCertChain(arena, nonRootChain).toString();
|
||||
}
|
||||
ret.caBytes = chain.back().certPem.toString();
|
||||
ret.keyBytes = chain.front().privateKeyPem.toString();
|
||||
return ret;
|
||||
}
|
||||
|
||||
enum class Result : int {
|
||||
TRUSTED = 0,
|
||||
UNTRUSTED,
|
||||
ERROR,
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fmt::formatter<Result> {
|
||||
constexpr auto parse(format_parse_context& ctx) -> decltype(ctx.begin()) { return ctx.begin(); }
|
||||
|
||||
template <class FormatContext>
|
||||
auto format(const Result& r, FormatContext& ctx) -> decltype(ctx.out()) {
|
||||
if (r == Result::TRUSTED)
|
||||
return fmt::format_to(ctx.out(), "TRUSTED");
|
||||
else if (r == Result::UNTRUSTED)
|
||||
return fmt::format_to(ctx.out(), "UNTRUSTED");
|
||||
else
|
||||
return fmt::format_to(ctx.out(), "ERROR");
|
||||
}
|
||||
};
|
||||
|
||||
ACTOR template <class T>
|
||||
Future<T> stopNetworkAfter(Future<T> what) {
|
||||
T t = wait(what);
|
||||
g_network->stop();
|
||||
return t;
|
||||
}
|
||||
|
||||
// Reflective struct containing information about the requester from a server PoV
|
||||
struct SessionInfo {
|
||||
constexpr static FileIdentifier file_identifier = 1578312;
|
||||
bool isPeerTrusted = false;
|
||||
NetworkAddress peerAddress;
|
||||
|
||||
template <class Ar>
|
||||
void serialize(Ar& ar) {
|
||||
serializer(ar, isPeerTrusted, peerAddress);
|
||||
}
|
||||
};
|
||||
|
||||
struct SessionProbeRequest {
|
||||
constexpr static FileIdentifier file_identifier = 1559713;
|
||||
ReplyPromise<SessionInfo> reply{ PeerCompatibilityPolicy{ RequirePeer::AtLeast,
|
||||
ProtocolVersion::withStableInterfaces() } };
|
||||
|
||||
bool verify() const { return true; }
|
||||
|
||||
template <class Ar>
|
||||
void serialize(Ar& ar) {
|
||||
serializer(ar, reply);
|
||||
}
|
||||
};
|
||||
|
||||
struct SessionProbeReceiver final : NetworkMessageReceiver {
|
||||
SessionProbeReceiver() {}
|
||||
void receive(ArenaObjectReader& reader) override {
|
||||
SessionProbeRequest req;
|
||||
reader.deserialize(req);
|
||||
SessionInfo res;
|
||||
res.isPeerTrusted = FlowTransport::transport().currentDeliveryPeerIsTrusted();
|
||||
res.peerAddress = FlowTransport::transport().currentDeliveryPeerAddress();
|
||||
req.reply.send(res);
|
||||
}
|
||||
PeerCompatibilityPolicy peerCompatibilityPolicy() const override {
|
||||
return PeerCompatibilityPolicy{ RequirePeer::AtLeast, ProtocolVersion::withStableInterfaces() };
|
||||
}
|
||||
bool isPublic() const override { return true; }
|
||||
};
|
||||
|
||||
Future<Void> runServer(Future<Void> listenFuture, const Endpoint& endpoint, int addrPipe, int completionPipe) {
|
||||
auto realAddr = FlowTransport::transport().getLocalAddresses().address;
|
||||
logs("Listening at {}", realAddr.toString());
|
||||
logs("Endpoint token is {}", endpoint.token.toString());
|
||||
// below writes/reads would block, but this is good enough for a test.
|
||||
if (sizeof(realAddr) != ::write(addrPipe, &realAddr, sizeof(realAddr))) {
|
||||
logs("Failed to write server addr to pipe: {}", strerror(errno));
|
||||
return Void();
|
||||
}
|
||||
if (sizeof(endpoint.token) != ::write(addrPipe, &endpoint.token, sizeof(endpoint.token))) {
|
||||
logs("Failed to write server endpoint to pipe: {}", strerror(errno));
|
||||
return Void();
|
||||
}
|
||||
auto done = false;
|
||||
if (sizeof(done) != ::read(completionPipe, &done, sizeof(done))) {
|
||||
logs("Failed to read completion flag from pipe: {}", strerror(errno));
|
||||
return Void();
|
||||
}
|
||||
return Void();
|
||||
}
|
||||
|
||||
ACTOR Future<Void> waitAndPrintResponse(Future<SessionInfo> response, Result* rc) {
|
||||
try {
|
||||
SessionInfo info = wait(response);
|
||||
logc("Probe response: trusted={} peerAddress={}", info.isPeerTrusted, info.peerAddress.toString());
|
||||
*rc = info.isPeerTrusted ? Result::TRUSTED : Result::UNTRUSTED;
|
||||
} catch (Error& err) {
|
||||
logc("Error: {}", err.what());
|
||||
*rc = Result::ERROR;
|
||||
}
|
||||
return Void();
|
||||
}
|
||||
|
||||
template <bool IsServer>
|
||||
int runHost(TLSCreds creds, int addrPipe, int completionPipe, Result expect) {
|
||||
auto tlsConfig = TLSConfig(IsServer ? TLSEndpointType::SERVER : TLSEndpointType::CLIENT);
|
||||
tlsConfig.setCertificateBytes(creds.certBytes);
|
||||
tlsConfig.setCABytes(creds.caBytes);
|
||||
tlsConfig.setKeyBytes(creds.keyBytes);
|
||||
g_network = newNet2(tlsConfig);
|
||||
openTraceFile(
|
||||
NetworkAddress(), 10 << 20, 10 << 20, ".", IsServer ? "authz_tls_unittest_server" : "authz_tls_unittest_client");
|
||||
FlowTransport::createInstance(!IsServer, 1, WLTOKEN_RESERVED_COUNT);
|
||||
auto& transport = FlowTransport::transport();
|
||||
if constexpr (IsServer) {
|
||||
auto addr = NetworkAddress::parse("127.0.0.1:0:tls");
|
||||
auto thread = std::thread([]() {
|
||||
g_network->run();
|
||||
flushTraceFileVoid();
|
||||
});
|
||||
auto endpoint = Endpoint();
|
||||
auto receiver = SessionProbeReceiver();
|
||||
transport.addEndpoint(endpoint, &receiver, TaskPriority::ReadSocket);
|
||||
runServer(transport.bind(addr, addr), endpoint, addrPipe, completionPipe);
|
||||
auto cleanupGuard = ScopeExit([&thread]() {
|
||||
g_network->stop();
|
||||
thread.join();
|
||||
});
|
||||
} else {
|
||||
auto dest = Endpoint();
|
||||
auto& serverAddr = dest.addresses.address;
|
||||
if (sizeof(serverAddr) != ::read(addrPipe, &serverAddr, sizeof(serverAddr))) {
|
||||
logc("Failed to read server addr from pipe: {}", strerror(errno));
|
||||
return 1;
|
||||
}
|
||||
auto& token = dest.token;
|
||||
if (sizeof(token) != ::read(addrPipe, &token, sizeof(token))) {
|
||||
logc("Failed to read server endpoint token from pipe: {}", strerror(errno));
|
||||
return 2;
|
||||
}
|
||||
logc("Server address is {}", serverAddr.toString());
|
||||
logc("Server endpoint token is {}", token.toString());
|
||||
auto sessionProbeReq = SessionProbeRequest{};
|
||||
transport.sendUnreliable(SerializeSource(sessionProbeReq), dest, true /*openConnection*/);
|
||||
logc("Request is sent");
|
||||
auto probeResponse = sessionProbeReq.reply.getFuture();
|
||||
auto result = Result::TRUSTED;
|
||||
auto timeout = delay(5);
|
||||
auto complete = waitAndPrintResponse(probeResponse, &result);
|
||||
auto f = stopNetworkAfter(complete || timeout);
|
||||
auto rc = 0;
|
||||
g_network->run();
|
||||
if (!complete.isReady()) {
|
||||
logc("Error: Probe request timed out");
|
||||
rc = 3;
|
||||
}
|
||||
auto done = true;
|
||||
if (sizeof(done) != ::write(completionPipe, &done, sizeof(done))) {
|
||||
logc("Failed to signal server to terminate: {}", strerror(errno));
|
||||
rc = 4;
|
||||
}
|
||||
if (rc == 0) {
|
||||
if (expect != result) {
|
||||
logc("Test failed: expected {}, got {}", expect, result);
|
||||
rc = 5;
|
||||
} else {
|
||||
logc("Response OK: got {} as expected", result);
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int runTlsTest(int serverChainLen, int clientChainLen) {
|
||||
log("==== BEGIN TESTCASE ====");
|
||||
auto expect = Result::ERROR;
|
||||
if (serverChainLen > 0) {
|
||||
if (clientChainLen > 0)
|
||||
expect = Result::TRUSTED;
|
||||
else if (clientChainLen == 0)
|
||||
expect = Result::UNTRUSTED;
|
||||
}
|
||||
log("Cert chain length: server={} client={}", serverChainLen, clientChainLen);
|
||||
auto arena = Arena();
|
||||
auto serverCreds = makeCreds(serverChainLen, mkcert::ESide::Server);
|
||||
auto clientCreds = makeCreds(clientChainLen, mkcert::ESide::Client);
|
||||
// make server and client trust each other
|
||||
std::swap(serverCreds.caBytes, clientCreds.caBytes);
|
||||
auto clientPid = pid_t{};
|
||||
auto serverPid = pid_t{};
|
||||
int addrPipe[2];
|
||||
int completionPipe[2];
|
||||
if (::pipe(addrPipe) || ::pipe(completionPipe)) {
|
||||
logm("Pipe open failed: {}", strerror(errno));
|
||||
return 1;
|
||||
}
|
||||
auto pipeCleanup = ScopeExit([&addrPipe, &completionPipe]() {
|
||||
::close(addrPipe[0]);
|
||||
::close(addrPipe[1]);
|
||||
::close(completionPipe[0]);
|
||||
::close(completionPipe[1]);
|
||||
});
|
||||
serverPid = fork();
|
||||
if (serverPid == 0) {
|
||||
_exit(runHost<true>(std::move(serverCreds), addrPipe[1], completionPipe[0], expect));
|
||||
}
|
||||
clientPid = fork();
|
||||
if (clientPid == 0) {
|
||||
_exit(runHost<false>(std::move(clientCreds), addrPipe[0], completionPipe[1], expect));
|
||||
}
|
||||
auto pid = pid_t{};
|
||||
auto status = int{};
|
||||
pid = waitpid(clientPid, &status, 0);
|
||||
auto ok = true;
|
||||
if (pid < 0) {
|
||||
logm("waitpid() for client failed with {}", strerror(errno));
|
||||
ok = false;
|
||||
} else {
|
||||
if (status != 0) {
|
||||
logm("Client error: rc={}", status);
|
||||
ok = false;
|
||||
} else {
|
||||
logm("Client OK");
|
||||
}
|
||||
}
|
||||
pid = waitpid(serverPid, &status, 0);
|
||||
if (pid < 0) {
|
||||
logm("waitpid() for server failed with {}", strerror(errno));
|
||||
ok = false;
|
||||
} else {
|
||||
if (status != 0) {
|
||||
logm("Server error: rc={}", status);
|
||||
ok = false;
|
||||
} else {
|
||||
logm("Server OK");
|
||||
}
|
||||
}
|
||||
log(ok ? "OK" : "FAILED");
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
std::pair<int, int> inputs[] = { { 3, 2 }, { 4, 0 }, { 1, 3 }, { 1, 0 }, { 2, 0 }, { 3, 3 }, { 3, 0 } };
|
||||
for (auto input : inputs) {
|
||||
auto [serverChainLen, clientChainLen] = input;
|
||||
if (auto rc = runTlsTest(serverChainLen, clientChainLen))
|
||||
return rc;
|
||||
}
|
||||
return 0;
|
||||
}
|
|
@ -80,3 +80,10 @@ target_compile_definitions(fdbrpc_sampling PRIVATE -DENABLE_SAMPLING)
|
|||
if(WIN32)
|
||||
add_dependencies(fdbrpc_sampling_actors fdbrpc_actors)
|
||||
endif()
|
||||
|
||||
if(UNIX)
|
||||
add_flow_target(EXECUTABLE NAME authz_tls_unittest SRCS AuthzTlsTest.actor.cpp)
|
||||
target_link_libraries(authz_tls_unittest PRIVATE flow fdbrpc fmt::fmt)
|
||||
add_test(NAME authorization_tls_unittest
|
||||
COMMAND $<TARGET_FILE:authz_tls_unittest>)
|
||||
endif()
|
||||
|
|
|
@ -958,7 +958,7 @@ void Peer::onIncomingConnection(Reference<Peer> self, Reference<IConnection> con
|
|||
.detail("FromAddr", conn->getPeerAddress())
|
||||
.detail("CanonicalAddr", destination)
|
||||
.detail("IsPublic", destination.isPublic())
|
||||
.detail("Trusted", self->transport->allowList(conn->getPeerAddress().ip));
|
||||
.detail("Trusted", self->transport->allowList(conn->getPeerAddress().ip) && conn->hasTrustedPeer());
|
||||
|
||||
connect.cancel();
|
||||
prependConnectPacket();
|
||||
|
@ -1257,7 +1257,7 @@ ACTOR static Future<Void> connectionReader(TransportData* transport,
|
|||
state bool incompatiblePeerCounted = false;
|
||||
state NetworkAddress peerAddress;
|
||||
state ProtocolVersion peerProtocolVersion;
|
||||
state bool trusted = transport->allowList(conn->getPeerAddress().ip);
|
||||
state bool trusted = transport->allowList(conn->getPeerAddress().ip) && conn->hasTrustedPeer();
|
||||
peerAddress = conn->getPeerAddress();
|
||||
|
||||
if (!peer) {
|
||||
|
|
|
@ -125,6 +125,10 @@ NetworkAddress SimExternalConnection::getPeerAddress() const {
|
|||
}
|
||||
}
|
||||
|
||||
bool SimExternalConnection::hasTrustedPeer() const {
|
||||
return true;
|
||||
}
|
||||
|
||||
UID SimExternalConnection::getDebugID() const {
|
||||
return dbgid;
|
||||
}
|
||||
|
|
|
@ -47,6 +47,7 @@ public:
|
|||
int read(uint8_t* begin, uint8_t* end) override;
|
||||
int write(SendBuffer const* buffer, int limit) override;
|
||||
NetworkAddress getPeerAddress() const override;
|
||||
bool hasTrustedPeer() const override;
|
||||
UID getDebugID() const override;
|
||||
boost::asio::ip::tcp::socket& getSocket() override { return socket; }
|
||||
static Future<std::vector<NetworkAddress>> resolveTCPEndpoint(const std::string& host,
|
||||
|
|
|
@ -208,7 +208,7 @@ SimClogging g_clogging;
|
|||
|
||||
struct Sim2Conn final : IConnection, ReferenceCounted<Sim2Conn> {
|
||||
Sim2Conn(ISimulator::ProcessInfo* process)
|
||||
: opened(false), closedByCaller(false), stableConnection(false), process(process),
|
||||
: opened(false), closedByCaller(false), stableConnection(false), trustedPeer(true), process(process),
|
||||
dbgid(deterministicRandom()->randomUniqueID()), stopReceive(Never()) {
|
||||
pipes = sender(this) && receiver(this);
|
||||
}
|
||||
|
@ -259,6 +259,8 @@ struct Sim2Conn final : IConnection, ReferenceCounted<Sim2Conn> {
|
|||
|
||||
bool isPeerGone() const { return !peer || peerProcess->failed; }
|
||||
|
||||
bool hasTrustedPeer() const override { return trustedPeer; }
|
||||
|
||||
bool isStableConnection() const override { return stableConnection; }
|
||||
|
||||
void peerClosed() {
|
||||
|
@ -327,7 +329,7 @@ struct Sim2Conn final : IConnection, ReferenceCounted<Sim2Conn> {
|
|||
|
||||
boost::asio::ip::tcp::socket& getSocket() override { throw operation_failed(); }
|
||||
|
||||
bool opened, closedByCaller, stableConnection;
|
||||
bool opened, closedByCaller, stableConnection, trustedPeer;
|
||||
|
||||
private:
|
||||
ISimulator::ProcessInfo *process, *peerProcess;
|
||||
|
|
|
@ -85,8 +85,3 @@ endif()
|
|||
|
||||
add_executable(mkcert MkCertCli.cpp)
|
||||
target_link_libraries(mkcert PUBLIC flow)
|
||||
|
||||
add_executable(mtls_unittest TLSTest.cpp)
|
||||
target_link_libraries(mtls_unittest PUBLIC flow)
|
||||
add_test(NAME mutual_tls_unittest
|
||||
COMMAND $<TARGET_FILE:mtls_unittest>)
|
||||
|
|
|
@ -236,6 +236,7 @@ public:
|
|||
int sslHandshakerThreadsStarted;
|
||||
int sslPoolHandshakesInProgress;
|
||||
TLSConfig tlsConfig;
|
||||
Reference<TLSPolicy> activeTlsPolicy;
|
||||
Future<Void> backgroundCertRefresh;
|
||||
ETLSInitState tlsInitializedState;
|
||||
|
||||
|
@ -505,6 +506,8 @@ public:
|
|||
|
||||
NetworkAddress getPeerAddress() const override { return peer_address; }
|
||||
|
||||
bool hasTrustedPeer() const override { return true; }
|
||||
|
||||
UID getDebugID() const override { return id; }
|
||||
|
||||
tcp::socket& getSocket() override { return socket; }
|
||||
|
@ -837,7 +840,7 @@ public:
|
|||
explicit SSLConnection(boost::asio::io_service& io_service,
|
||||
Reference<ReferencedObject<boost::asio::ssl::context>> context)
|
||||
: id(nondeterministicRandom()->randomUniqueID()), socket(io_service), ssl_sock(socket, context->mutate()),
|
||||
sslContext(context) {}
|
||||
sslContext(context), has_trusted_peer(false) {}
|
||||
|
||||
explicit SSLConnection(Reference<ReferencedObject<boost::asio::ssl::context>> context, tcp::socket* existingSocket)
|
||||
: id(nondeterministicRandom()->randomUniqueID()), socket(std::move(*existingSocket)),
|
||||
|
@ -898,6 +901,9 @@ public:
|
|||
|
||||
try {
|
||||
Future<Void> onHandshook;
|
||||
ConfigureSSLStream(N2::g_net2->activeTlsPolicy, self->ssl_sock, [this](bool verifyOk) {
|
||||
self->has_trusted_peer = verifyOk;
|
||||
});
|
||||
|
||||
// If the background handshakers are not all busy, use one
|
||||
if (N2::g_net2->sslPoolHandshakesInProgress < N2::g_net2->sslHandshakerThreadsStarted) {
|
||||
|
@ -973,6 +979,10 @@ public:
|
|||
|
||||
try {
|
||||
Future<Void> onHandshook;
|
||||
ConfigureSSLStream(N2::g_net2->activeTlsPolicy, self->ssl_sock, [this](bool verifyOk) {
|
||||
self->has_trusted_peer = verifyOk;
|
||||
});
|
||||
|
||||
// If the background handshakers are not all busy, use one
|
||||
if (N2::g_net2->sslPoolHandshakesInProgress < N2::g_net2->sslHandshakerThreadsStarted) {
|
||||
holder = Hold(&N2::g_net2->sslPoolHandshakesInProgress);
|
||||
|
@ -1106,6 +1116,10 @@ public:
|
|||
|
||||
NetworkAddress getPeerAddress() const override { return peer_address; }
|
||||
|
||||
bool hasTrustedPeer() const override {
|
||||
return has_trusted_peer;
|
||||
}
|
||||
|
||||
UID getDebugID() const override { return id; }
|
||||
|
||||
tcp::socket& getSocket() override { return socket; }
|
||||
|
@ -1118,6 +1132,7 @@ private:
|
|||
ssl_socket ssl_sock;
|
||||
NetworkAddress peer_address;
|
||||
Reference<ReferencedObject<boost::asio::ssl::context>> sslContext;
|
||||
bool has_trusted_peer;
|
||||
|
||||
void init() {
|
||||
// Socket settings that have to be set after connect or accept succeeds
|
||||
|
@ -1163,6 +1178,16 @@ public:
|
|||
NetworkAddress listenAddress)
|
||||
: io_service(io_service), listenAddress(listenAddress), acceptor(io_service, tcpEndpoint(listenAddress)),
|
||||
contextVar(contextVar) {
|
||||
// when port 0 is passed in, a random port will be opened
|
||||
// set listenAddress as the address with the actual port opened instead of port 0
|
||||
if (listenAddress.port == 0) {
|
||||
this->listenAddress = NetworkAddress::parse(acceptor.local_endpoint()
|
||||
.address()
|
||||
.to_string()
|
||||
.append(":")
|
||||
.append(std::to_string(acceptor.local_endpoint().port()))
|
||||
.append(listenAddress.isTLS() ? ":tls" : ""));
|
||||
}
|
||||
platform::setCloseOnExec(acceptor.native_handle());
|
||||
}
|
||||
|
||||
|
@ -1276,7 +1301,8 @@ ACTOR static Future<Void> watchFileForChanges(std::string filename, AsyncTrigger
|
|||
ACTOR static Future<Void> reloadCertificatesOnChange(
|
||||
TLSConfig config,
|
||||
std::function<void()> onPolicyFailure,
|
||||
AsyncVar<Reference<ReferencedObject<boost::asio::ssl::context>>>* contextVar) {
|
||||
AsyncVar<Reference<ReferencedObject<boost::asio::ssl::context>>>* contextVar,
|
||||
Reference<TLSPolicy>* policy) {
|
||||
if (FLOW_KNOBS->TLS_CERT_REFRESH_DELAY_SECONDS <= 0) {
|
||||
return Void();
|
||||
}
|
||||
|
@ -1300,7 +1326,8 @@ ACTOR static Future<Void> reloadCertificatesOnChange(
|
|||
try {
|
||||
LoadedTLSConfig loaded = wait(config.loadAsync());
|
||||
boost::asio::ssl::context context(boost::asio::ssl::context::tls);
|
||||
ConfigureSSLContext(loaded, &context, onPolicyFailure);
|
||||
ConfigureSSLContext(loaded, context);
|
||||
*policy = makeReference<TLSPolicy>(loaded, onPolicyFailure);
|
||||
TraceEvent(SevInfo, "TLSCertificateRefreshSucceeded").log();
|
||||
mismatches = 0;
|
||||
contextVar->set(ReferencedObject<boost::asio::ssl::context>::from(std::move(context)));
|
||||
|
@ -1332,12 +1359,15 @@ void Net2::initTLS(ETLSInitState targetState) {
|
|||
.detail("KeyPath", tlsConfig.getKeyPathSync())
|
||||
.detail("HasPassword", !loaded.getPassword().empty())
|
||||
.detail("VerifyPeers", boost::algorithm::join(loaded.getVerifyPeers(), "|"));
|
||||
ConfigureSSLContext(tlsConfig.loadSync(), &newContext, onPolicyFailure);
|
||||
auto loadedTlsConfig = tlsConfig.loadSync();
|
||||
ConfigureSSLContext(loadedTlsConfig, newContext);
|
||||
activeTlsPolicy = makeReference<TLSPolicy>(loadedTlsConfig, onPolicyFailure);
|
||||
sslContextVar.set(ReferencedObject<boost::asio::ssl::context>::from(std::move(newContext)));
|
||||
} catch (Error& e) {
|
||||
TraceEvent("Net2TLSInitError").error(e);
|
||||
}
|
||||
backgroundCertRefresh = reloadCertificatesOnChange(tlsConfig, onPolicyFailure, &sslContextVar);
|
||||
backgroundCertRefresh =
|
||||
reloadCertificatesOnChange(tlsConfig, onPolicyFailure, &sslContextVar, &activeTlsPolicy);
|
||||
}
|
||||
|
||||
// If a TLS connection is actually going to be used then start background threads if configured
|
||||
|
|
|
@ -81,7 +81,7 @@ void LoadedTLSConfig::print(FILE* fp) {
|
|||
int num_certs = 0;
|
||||
boost::asio::ssl::context context(boost::asio::ssl::context::tls);
|
||||
try {
|
||||
ConfigureSSLContext(*this, &context);
|
||||
ConfigureSSLContext(*this, context);
|
||||
} catch (Error& e) {
|
||||
fprintf(fp, "There was an error in loading the certificate chain.\n");
|
||||
throw;
|
||||
|
@ -109,51 +109,58 @@ void LoadedTLSConfig::print(FILE* fp) {
|
|||
X509_STORE_CTX_free(store_ctx);
|
||||
}
|
||||
|
||||
void ConfigureSSLContext(const LoadedTLSConfig& loaded,
|
||||
boost::asio::ssl::context* context,
|
||||
std::function<void()> onPolicyFailure) {
|
||||
void ConfigureSSLContext(const LoadedTLSConfig& loaded, boost::asio::ssl::context& context) {
|
||||
try {
|
||||
context->set_options(boost::asio::ssl::context::default_workarounds);
|
||||
context->set_verify_mode(boost::asio::ssl::context::verify_peer |
|
||||
boost::asio::ssl::verify_fail_if_no_peer_cert);
|
||||
context.set_options(boost::asio::ssl::context::default_workarounds);
|
||||
auto verifyFailIfNoPeerCert = boost::asio::ssl::verify_fail_if_no_peer_cert;
|
||||
// Servers get to accept connections without peer certs as "untrusted" clients
|
||||
if (loaded.getEndpointType() == TLSEndpointType::SERVER)
|
||||
verifyFailIfNoPeerCert = 0;
|
||||
context.set_verify_mode(boost::asio::ssl::context::verify_peer | verifyFailIfNoPeerCert);
|
||||
|
||||
if (loaded.isTLSEnabled()) {
|
||||
auto tlsPolicy = makeReference<TLSPolicy>(loaded.getEndpointType());
|
||||
tlsPolicy->set_verify_peers({ loaded.getVerifyPeers() });
|
||||
|
||||
context->set_verify_callback(
|
||||
[policy = tlsPolicy, onPolicyFailure](bool preverified, boost::asio::ssl::verify_context& ctx) {
|
||||
bool success = policy->verify_peer(preverified, ctx.native_handle());
|
||||
if (!success) {
|
||||
onPolicyFailure();
|
||||
}
|
||||
return success;
|
||||
});
|
||||
} else {
|
||||
// Insecurely always except if TLS is not enabled.
|
||||
context->set_verify_callback([](bool, boost::asio::ssl::verify_context&) { return true; });
|
||||
}
|
||||
|
||||
context->set_password_callback([password = loaded.getPassword()](
|
||||
size_t, boost::asio::ssl::context::password_purpose) { return password; });
|
||||
context.set_password_callback([password = loaded.getPassword()](
|
||||
size_t, boost::asio::ssl::context::password_purpose) { return password; });
|
||||
|
||||
const std::string& CABytes = loaded.getCABytes();
|
||||
if (CABytes.size()) {
|
||||
context->add_certificate_authority(boost::asio::buffer(CABytes.data(), CABytes.size()));
|
||||
context.add_certificate_authority(boost::asio::buffer(CABytes.data(), CABytes.size()));
|
||||
}
|
||||
|
||||
const std::string& keyBytes = loaded.getKeyBytes();
|
||||
if (keyBytes.size()) {
|
||||
context->use_private_key(boost::asio::buffer(keyBytes.data(), keyBytes.size()),
|
||||
boost::asio::ssl::context::pem);
|
||||
context.use_private_key(boost::asio::buffer(keyBytes.data(), keyBytes.size()),
|
||||
boost::asio::ssl::context::pem);
|
||||
}
|
||||
|
||||
const std::string& certBytes = loaded.getCertificateBytes();
|
||||
if (certBytes.size()) {
|
||||
context->use_certificate_chain(boost::asio::buffer(certBytes.data(), certBytes.size()));
|
||||
context.use_certificate_chain(boost::asio::buffer(certBytes.data(), certBytes.size()));
|
||||
}
|
||||
} catch (boost::system::system_error& e) {
|
||||
TraceEvent("TLSConfigureError")
|
||||
TraceEvent("TLSContextConfigureError")
|
||||
.detail("What", e.what())
|
||||
.detail("Value", e.code().value())
|
||||
.detail("WhichMeans", TLSPolicy::ErrorString(e.code()));
|
||||
throw tls_error();
|
||||
}
|
||||
}
|
||||
|
||||
void ConfigureSSLStream(Reference<TLSPolicy> policy,
|
||||
boost::asio::ssl::stream<boost::asio::ip::tcp::socket&>& stream,
|
||||
std::function<void(bool)> callback) {
|
||||
try {
|
||||
stream.set_verify_callback([policy, callback](bool preverified, boost::asio::ssl::verify_context& ctx) {
|
||||
bool success = policy->verify_peer(preverified, ctx.native_handle());
|
||||
if (!success) {
|
||||
if (policy->on_failure)
|
||||
policy->on_failure();
|
||||
}
|
||||
if (callback)
|
||||
callback(success);
|
||||
return success;
|
||||
});
|
||||
} catch (boost::system::system_error& e) {
|
||||
TraceEvent("TLSStreamConfigureError")
|
||||
.detail("What", e.what())
|
||||
.detail("Value", e.code().value())
|
||||
.detail("WhichMeans", TLSPolicy::ErrorString(e.code()));
|
||||
|
@ -261,6 +268,11 @@ LoadedTLSConfig TLSConfig::loadSync() const {
|
|||
return loaded;
|
||||
}
|
||||
|
||||
TLSPolicy::TLSPolicy(const LoadedTLSConfig& loaded, std::function<void()> on_failure)
|
||||
: rules(), on_failure(std::move(on_failure)), is_client(loaded.getEndpointType() == TLSEndpointType::CLIENT) {
|
||||
set_verify_peers(loaded.getVerifyPeers());
|
||||
}
|
||||
|
||||
// And now do the same thing, but async...
|
||||
|
||||
ACTOR static Future<Void> readEntireFile(std::string filename, std::string* destination) {
|
||||
|
|
|
@ -33,6 +33,8 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <boost/system/system_error.hpp>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/ssl.hpp>
|
||||
#include "flow/FastRef.h"
|
||||
#include "flow/Knobs.h"
|
||||
#include "flow/flow.h"
|
||||
|
@ -201,21 +203,23 @@ private:
|
|||
TLSEndpointType endpointType = TLSEndpointType::UNSET;
|
||||
};
|
||||
|
||||
namespace boost {
|
||||
namespace asio {
|
||||
namespace ssl {
|
||||
struct context;
|
||||
}
|
||||
} // namespace asio
|
||||
} // namespace boost
|
||||
void ConfigureSSLContext(
|
||||
const LoadedTLSConfig& loaded,
|
||||
boost::asio::ssl::context* context,
|
||||
std::function<void()> onPolicyFailure = []() {});
|
||||
class TLSPolicy;
|
||||
|
||||
void ConfigureSSLContext(const LoadedTLSConfig& loaded, boost::asio::ssl::context& context);
|
||||
|
||||
// Set up SSL for stream object based on policy.
|
||||
// Optionally arm a callback that gets called with verify-outcome of each cert in peer certificate chain:
|
||||
// e.g. for peer with a valid, trusted length-3 certificate chain (root CA, intermediate CA, and server certs),
|
||||
// callback(true) will be called 3 times.
|
||||
void ConfigureSSLStream(Reference<TLSPolicy> policy,
|
||||
boost::asio::ssl::stream<boost::asio::ip::tcp::socket&>& stream,
|
||||
std::function<void(bool)> callback);
|
||||
|
||||
class TLSPolicy : ReferenceCounted<TLSPolicy> {
|
||||
void set_verify_peers(std::vector<std::string> verify_peers);
|
||||
|
||||
public:
|
||||
TLSPolicy(TLSEndpointType client) : is_client(client == TLSEndpointType::CLIENT) {}
|
||||
TLSPolicy(const LoadedTLSConfig& loaded, std::function<void()> on_failure);
|
||||
virtual ~TLSPolicy();
|
||||
|
||||
virtual void addref() { ReferenceCounted<TLSPolicy>::addref(); }
|
||||
|
@ -223,7 +227,6 @@ public:
|
|||
|
||||
static std::string ErrorString(boost::system::error_code e);
|
||||
|
||||
void set_verify_peers(std::vector<std::string> verify_peers);
|
||||
bool verify_peer(bool preverified, X509_STORE_CTX* store_ctx);
|
||||
|
||||
std::string toString() const;
|
||||
|
@ -242,6 +245,7 @@ public:
|
|||
};
|
||||
|
||||
std::vector<Rule> rules;
|
||||
std::function<void()> on_failure;
|
||||
bool is_client;
|
||||
};
|
||||
|
||||
|
|
|
@ -467,6 +467,11 @@ public:
|
|||
// this may not be an address we can connect to!
|
||||
virtual NetworkAddress getPeerAddress() const = 0;
|
||||
|
||||
// Returns whether the peer is trusted.
|
||||
// For TLS-enabled connections, this is true if the peer has presented a valid chain of certificates trusted by the
|
||||
// local endpoint. For non-TLS connections this is always true for any valid open connection.
|
||||
virtual bool hasTrustedPeer() const = 0;
|
||||
|
||||
virtual UID getDebugID() const = 0;
|
||||
|
||||
// At present, implemented by Sim2Conn where we want to disable bits flip for connections between parent process and
|
||||
|
|
Loading…
Reference in New Issue