Update sidecar to match latest changes in operator

This commit is contained in:
Johannes M. Scheuermann 2021-07-07 14:26:51 +01:00
parent cc72c5e23c
commit 33928904ec
1 changed files with 132 additions and 48 deletions

View File

@ -1,10 +1,10 @@
#! /usr/bin/env python3 #! /usr/bin/python
# entrypoint.py # entrypoint.py
# #
# This source file is part of the FoundationDB open source project # This source file is part of the FoundationDB open source project
# #
# Copyright 2018-2019 Apple Inc. and the FoundationDB project authors # Copyright 2018-2021 Apple Inc. and the FoundationDB project authors
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -21,10 +21,11 @@
import argparse import argparse
import hashlib import hashlib
import http.server import ipaddress
import logging import logging
import json import json
import os import os
import re
import shutil import shutil
import socket import socket
import ssl import ssl
@ -32,8 +33,11 @@ import stat
import time import time
import traceback import traceback
import sys import sys
import tempfile
from pathlib import Path from pathlib import Path
from http.server import HTTPServer, ThreadingHTTPServer, BaseHTTPRequestHandler
from watchdog.observers import Observer from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler from watchdog.events import FileSystemEventHandler
@ -53,12 +57,10 @@ class Config(object):
), ),
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument("--bind-address", help="IP and port to bind on")
"--bind-address", help="IP and port to bind on", default="0.0.0.0:8080"
)
parser.add_argument( parser.add_argument(
"--tls", "--tls",
help=("This flag enables TLS for incoming " "connections"), help=("This flag enables TLS for incoming connections"),
action="store_true", action="store_true",
) )
parser.add_argument( parser.add_argument(
@ -103,7 +105,7 @@ class Config(object):
) )
parser.add_argument( parser.add_argument(
"--input-dir", "--input-dir",
help=("The directory containing the input files " "the config map."), help=("The directory containing the input files the config map."),
default="/var/input-files", default="/var/input-files",
) )
parser.add_argument( parser.add_argument(
@ -125,28 +127,33 @@ class Config(object):
) )
parser.add_argument( parser.add_argument(
"--copy-file", "--copy-file",
help=("A file to copy from the config map to the " "output directory."), help=("A file to copy from the config map to the output directory."),
action="append", action="append",
) )
parser.add_argument( parser.add_argument(
"--copy-binary", "--copy-binary",
help=("A binary to copy from the to the output" "directory."), help=("A binary to copy from the to the output directory."),
action="append", action="append",
) )
parser.add_argument( parser.add_argument(
"--copy-library", "--copy-library",
help=( help=("A version of the client library to copy to the output directory."),
"A version of the client library to copy " "to the output directory."
),
action="append", action="append",
) )
parser.add_argument( parser.add_argument(
"--input-monitor-conf", "--input-monitor-conf",
help=("The name of a monitor conf template in the " "input files"), help=("The name of a monitor conf template in the input files"),
) )
parser.add_argument( parser.add_argument(
"--main-container-version", "--main-container-version",
help=("The version of the main foundationdb " "container in the pod"), help=("The version of the main foundationdb container in the pod"),
)
parser.add_argument(
"--public-ip-family",
help=(
"Tells the sidecar to treat the public IP as a comma-separated "
"list, and use the first entry in the specified IP family"
),
) )
parser.add_argument( parser.add_argument(
"--main-container-conf-dir", "--main-container-conf-dir",
@ -216,6 +223,7 @@ class Config(object):
"FDB_MACHINE_ID", "FDB_MACHINE_ID",
"FDB_ZONE_ID", "FDB_ZONE_ID",
"FDB_INSTANCE_ID", "FDB_INSTANCE_ID",
"FDB_POD_IP",
]: ]:
self.substitutions[key] = os.getenv(key, "") self.substitutions[key] = os.getenv(key, "")
@ -225,13 +233,12 @@ class Config(object):
if self.substitutions["FDB_ZONE_ID"] == "": if self.substitutions["FDB_ZONE_ID"] == "":
self.substitutions["FDB_ZONE_ID"] = self.substitutions["FDB_MACHINE_ID"] self.substitutions["FDB_ZONE_ID"] = self.substitutions["FDB_MACHINE_ID"]
if self.substitutions["FDB_PUBLIC_IP"] == "": if self.substitutions["FDB_PUBLIC_IP"] == "":
address_info = socket.getaddrinfo( # As long as the public IP is not set fallback to the
self.substitutions["FDB_MACHINE_ID"], # Pod IP address.
4500, pod_ip = os.getenv("FDB_POD_IP")
family=socket.AddressFamily.AF_INET, if pod_ip is none:
) pod_ip = socket.gethostbyname(socket.gethostname())
if len(address_info) > 0: self.substitutions["FDB_PUBLIC_IP"] = pod_ip
self.substitutions["FDB_PUBLIC_IP"] = address_info[0][4][0]
if self.main_container_version == self.primary_version: if self.main_container_version == self.primary_version:
self.substitutions["BINARY_DIR"] = "/usr/bin" self.substitutions["BINARY_DIR"] = "/usr/bin"
@ -290,6 +297,21 @@ class Config(object):
if os.getenv("COPY_ONCE", "0") == "1": if os.getenv("COPY_ONCE", "0") == "1":
self.init_mode = True self.init_mode = True
if args.public_ip_family:
version = int(args.public_ip_family)
self.substitutions["FDB_PUBLIC_IP"] = Config.extract_desired_ip(
version, self.substitutions["FDB_PUBLIC_IP"]
)
self.substitutions["FDB_POD_IP"] = Config.extract_desired_ip(
version, self.substitutions["FDB_POD_IP"]
)
if not self.bind_address:
if self.substitutions["FDB_POD_IP"] != "":
self.bind_address = self.substitutions["FDB_POD_IP"] + ":8080"
else:
self.bind_address = self.substitutions["FDB_PUBLIC_IP"] + ":8080"
@classmethod @classmethod
def shared(cls): def shared(cls):
if cls.shared_config: if cls.shared_config:
@ -305,8 +327,26 @@ class Config(object):
and self.minor_version[1] >= target_version[1] and self.minor_version[1] >= target_version[1]
) )
@classmethod
def extract_desired_ip(cls, version, string):
if string == "":
return string
class Server(http.server.BaseHTTPRequestHandler): ips = string.split(",")
matching_ips = [ip for ip in ips if ipaddress.ip_address(ip).version == version]
if len(matching_ips) == 0:
raise Exception(f"Failed to find IPv{version} entry in {ips}")
ip = matching_ips[0]
if version == 6:
ip = f"[{ip}]"
return ip
class ThreadingHTTPServerV6(ThreadingHTTPServer):
address_family = socket.AF_INET6
class Server(BaseHTTPRequestHandler):
ssl_context = None ssl_context = None
@classmethod @classmethod
@ -315,13 +355,20 @@ class Server(http.server.BaseHTTPRequestHandler):
This method starts the server. This method starts the server.
""" """
config = Config.shared() config = Config.shared()
(address, port) = config.bind_address.split(":") colon_index = config.bind_address.rindex(":")
log.info("Listening on %s:%s" % (address, port)) port_index = colon_index + 1
httpd = http.server.HTTPServer((address, int(port)), cls) address = config.bind_address[:colon_index]
port = config.bind_address[port_index:]
log.info(f"Listening on {address}:{port}")
if address.startswith("[") and address.endswith("]"):
server = ThreadingHTTPServerV6((address[1:-1], int(port)), cls)
else:
server = ThreadingHTTPServer((address, int(port)), cls)
if config.enable_tls: if config.enable_tls:
context = Server.load_ssl_context() context = Server.load_ssl_context()
httpd.socket = context.wrap_socket(httpd.socket, server_side=True) server.socket = context.wrap_socket(server.socket, server_side=True)
observer = Observer() observer = Observer()
event_handler = CertificateEventHandler() event_handler = CertificateEventHandler()
for path in set( for path in set(
@ -333,7 +380,7 @@ class Server(http.server.BaseHTTPRequestHandler):
observer.schedule(event_handler, path) observer.schedule(event_handler, path)
observer.start() observer.start()
httpd.serve_forever() server.serve_forever()
@classmethod @classmethod
def load_ssl_context(cls): def load_ssl_context(cls):
@ -341,7 +388,7 @@ class Server(http.server.BaseHTTPRequestHandler):
if not cls.ssl_context: if not cls.ssl_context:
cls.ssl_context = ssl.create_default_context(cafile=config.ca_file) cls.ssl_context = ssl.create_default_context(cafile=config.ca_file)
cls.ssl_context.check_hostname = False cls.ssl_context.check_hostname = False
cls.ssl_context.verify_mode = ssl.CERT_REQUIRED cls.ssl_context.verify_mode = ssl.CERT_OPTIONAL
cls.ssl_context.load_cert_chain(config.certificate_file, config.key_file) cls.ssl_context.load_cert_chain(config.certificate_file, config.key_file)
return cls.ssl_context return cls.ssl_context
@ -359,13 +406,21 @@ class Server(http.server.BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
self.wfile.write(response) self.wfile.write(response)
def check_request_cert(self): def check_request_cert(self, path):
config = Config.shared() config = Config.shared()
approved = not config.enable_tls or self.check_cert(
if path == "/ready":
return True
if not config.enable_tls:
return True
approved = self.check_cert(
self.connection.getpeercert(), config.peer_verification_rules self.connection.getpeercert(), config.peer_verification_rules
) )
if not approved: if not approved:
self.send_error(401, "Client certificate was not approved") self.send_error(401, "Client certificate was not approved")
return approved return approved
def check_cert(self, cert, rules): def check_cert(self, cert, rules):
@ -375,6 +430,9 @@ class Server(http.server.BaseHTTPRequestHandler):
If there is any problem with the certificate, this will return a string If there is any problem with the certificate, this will return a string
describing the error. describing the error.
""" """
if cert is None:
return False
if not rules: if not rules:
return True return True
@ -456,7 +514,7 @@ class Server(http.server.BaseHTTPRequestHandler):
This method executes a GET request. This method executes a GET request.
""" """
try: try:
if not self.check_request_cert(): if not self.check_request_cert(self.path):
return return
if self.path.startswith("/check_hash/"): if self.path.startswith("/check_hash/"):
try: try:
@ -483,7 +541,7 @@ class Server(http.server.BaseHTTPRequestHandler):
This method executes a POST request. This method executes a POST request.
""" """
try: try:
if not self.check_request_cert(): if not self.check_request_cert(self.path):
return return
if self.path == "/copy_files": if self.path == "/copy_files":
self.send_text(copy_files()) self.send_text(copy_files())
@ -516,7 +574,19 @@ class Server(http.server.BaseHTTPRequestHandler):
class CertificateEventHandler(FileSystemEventHandler): class CertificateEventHandler(FileSystemEventHandler):
def on_any_event(self, event): def on_any_event(self, event):
log.info("Detected change to certificates") if event.is_directory:
return None
if event.event_type not in ["created", "modified"]:
return None
# We ignore all old files
if event.src_path.endswith(".old"):
return None
log.info(
f"Detected change to certificates path: {event.src_path}, type: {event.event_type }"
)
time.sleep(10) time.sleep(10)
log.info("Reloading certificates") log.info("Reloading certificates")
Server.load_ssl_context() Server.load_ssl_context()
@ -536,10 +606,13 @@ def copy_files():
path = os.path.join(config.input_dir, filename) path = os.path.join(config.input_dir, filename)
if not os.path.isfile(path) or os.path.getsize(path) == 0: if not os.path.isfile(path) or os.path.getsize(path) == 0:
raise Exception("No contents for file %s" % path) raise Exception("No contents for file %s" % path)
for filename in config.copy_files: for filename in config.copy_files:
tmp_file = os.path.join(config.output_dir, f"{filename}.tmp") tmp_file = tempfile.NamedTemporaryFile(
shutil.copy(os.path.join(config.input_dir, filename), tmp_file) mode="w+b", dir=config.output_dir, delete=False
os.replace(tmp_file, os.path.join(config.output_dir, filename)) )
shutil.copy(os.path.join(config.input_dir, filename), tmp_file.name)
os.replace(tmp_file.name, os.path.join(config.output_dir, filename))
return "OK" return "OK"
@ -554,9 +627,13 @@ def copy_binaries():
) )
if not target_path.exists(): if not target_path.exists():
target_path.parent.mkdir(parents=True, exist_ok=True) target_path.parent.mkdir(parents=True, exist_ok=True)
tmp_file = f"{target_path}.tmp" tmp_file = tempfile.NamedTemporaryFile(
shutil.copy(path, tmp_file) mode="w+b",
os.replace(tmp_file, target_path) dir=target_path.parent,
delete=False,
)
shutil.copy(path, tmp_file.name)
os.replace(tmp_file.name, target_path)
target_path.chmod(0o744) target_path.chmod(0o744)
return "OK" return "OK"
@ -573,9 +650,11 @@ def copy_libraries():
) )
if not target_path.exists(): if not target_path.exists():
target_path.parent.mkdir(parents=True, exist_ok=True) target_path.parent.mkdir(parents=True, exist_ok=True)
tmp_file = f"{target_path}.tmp" tmp_file = tempfile.NamedTemporaryFile(
shutil.copy(path, tmp_file) mode="w+b", dir=target_path.parent, delete=False
os.replace(tmp_file, target_path) )
shutil.copy(path, tmp_file.name)
os.replace(tmp_file.name, target_path)
return "OK" return "OK"
@ -591,13 +670,16 @@ def copy_monitor_conf():
"$" + variable, config.substitutions[variable] "$" + variable, config.substitutions[variable]
) )
tmp_file = os.path.join(config.output_dir, "fdbmonitor.conf.tmp") tmp_file = tempfile.NamedTemporaryFile(
mode="w+b", dir=config.output_dir, delete=False
)
target_file = os.path.join(config.output_dir, "fdbmonitor.conf") target_file = os.path.join(config.output_dir, "fdbmonitor.conf")
with open(tmp_file, "w") as output_conf_file: with open(tmp_file.name, "w") as output_conf_file:
output_conf_file.write(monitor_conf) output_conf_file.write(monitor_conf)
os.replace(tmp_file, target_file) os.replace(tmp_file.name, target_file)
return "OK" return "OK"
@ -629,5 +711,7 @@ if __name__ == "__main__":
copy_libraries() copy_libraries()
copy_monitor_conf() copy_monitor_conf()
if not Config.shared().init_mode: if Config.shared().init_mode:
Server.start() sys.exit(0)
Server.start()