Support SSL Key Rotation in HTTP Server (#13495)

This commit is contained in:
Keyun Tong 2025-02-22 05:17:44 -08:00 committed by GitHub
parent 2382ad29d1
commit 8db1b9d0a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 173 additions and 2 deletions

View File

@ -20,7 +20,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines == 0.1.11
lark == 1.2.2
lark == 1.2.2
xgrammar == 0.1.11; platform_machine == "x86_64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
@ -37,3 +37,4 @@ einops # Required for Qwen2-VL.
compressed-tensors == 0.9.2 # required for compressed-tensors
depyf==0.18.0 # required for profiling and debugging with compilation config
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
watchfiles # required for http server to monitor the updates of TLS files

View File

@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import tempfile
from pathlib import Path
from ssl import SSLContext
import pytest
from vllm.entrypoints.ssl import SSLCertRefresher
class MockSSLContext(SSLContext):
def __init__(self):
self.load_cert_chain_count = 0
self.load_ca_count = 0
def load_cert_chain(
self,
certfile,
keyfile=None,
password=None,
):
self.load_cert_chain_count += 1
def load_verify_locations(
self,
cafile=None,
capath=None,
cadata=None,
):
self.load_ca_count += 1
def create_file() -> str:
with tempfile.NamedTemporaryFile(dir='/tmp', delete=False) as f:
return f.name
def touch_file(path: str) -> None:
Path(path).touch()
@pytest.mark.asyncio
async def test_ssl_refresher():
ssl_context = MockSSLContext()
key_path = create_file()
cert_path = create_file()
ca_path = create_file()
ssl_refresher = SSLCertRefresher(ssl_context, key_path, cert_path, ca_path)
await asyncio.sleep(1)
assert ssl_context.load_cert_chain_count == 0
assert ssl_context.load_ca_count == 0
touch_file(key_path)
await asyncio.sleep(1)
assert ssl_context.load_cert_chain_count == 1
assert ssl_context.load_ca_count == 0
touch_file(cert_path)
touch_file(ca_path)
await asyncio.sleep(1)
assert ssl_context.load_cert_chain_count == 2
assert ssl_context.load_ca_count == 1
ssl_refresher.stop()
touch_file(cert_path)
touch_file(ca_path)
await asyncio.sleep(1)
assert ssl_context.load_cert_chain_count == 2
assert ssl_context.load_ca_count == 1

View File

@ -128,6 +128,7 @@ async def run_server(args: Namespace,
shutdown_task = await serve_http(
app,
sock=None,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.log_level,
@ -152,6 +153,11 @@ if __name__ == "__main__":
type=str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--enable-ssl-refresh",
action="store_true",
default=False,
help="Refresh SSL Context when SSL certificate files change")
parser.add_argument(
"--ssl-cert-reqs",
type=int,

View File

@ -12,13 +12,16 @@ from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
logger = init_logger(__name__)
async def serve_http(app: FastAPI, sock: Optional[socket.socket],
async def serve_http(app: FastAPI,
sock: Optional[socket.socket],
enable_ssl_refresh: bool = False,
**uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
config = uvicorn.Config(app, **uvicorn_kwargs)
config.load()
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server)
@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
server_task = loop.create_task(
server.serve(sockets=[sock] if sock else None))
ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
ssl_context=config.ssl,
key_path=config.ssl_keyfile,
cert_path=config.ssl_certfile,
ca_path=config.ssl_ca_certs)
def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()
async def dummy_shutdown() -> None:
pass

View File

@ -960,6 +960,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
shutdown_task = await serve_http(
app,
sock=sock,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,

View File

@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=nullable_str,
default=None,
help="The CA certificates file.")
parser.add_argument(
"--enable-ssl-refresh",
action="store_true",
default=False,
help="Refresh SSL Context when SSL certificate files change")
parser.add_argument(
"--ssl-cert-reqs",
type=int,

74
vllm/entrypoints/ssl.py Normal file
View File

@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from ssl import SSLContext
from typing import Callable, Optional
from watchfiles import Change, awatch
from vllm.logger import init_logger
logger = init_logger(__name__)
class SSLCertRefresher:
"""A class that monitors SSL certificate files and
reloads them when they change.
"""
def __init__(self,
ssl_context: SSLContext,
key_path: Optional[str] = None,
cert_path: Optional[str] = None,
ca_path: Optional[str] = None) -> None:
self.ssl = ssl_context
self.key_path = key_path
self.cert_path = cert_path
self.ca_path = ca_path
# Setup certification chain watcher
def update_ssl_cert_chain(change: Change, file_path: str) -> None:
logger.info("Reloading SSL certificate chain")
assert self.key_path and self.cert_path
self.ssl.load_cert_chain(self.cert_path, self.key_path)
self.watch_ssl_cert_task = None
if self.key_path and self.cert_path:
self.watch_ssl_cert_task = asyncio.create_task(
self._watch_files([self.key_path, self.cert_path],
update_ssl_cert_chain))
# Setup CA files watcher
def update_ssl_ca(change: Change, file_path: str) -> None:
logger.info("Reloading SSL CA certificates")
assert self.ca_path
self.ssl.load_verify_locations(self.ca_path)
self.watch_ssl_ca_task = None
if self.ca_path:
self.watch_ssl_ca_task = asyncio.create_task(
self._watch_files([self.ca_path], update_ssl_ca))
async def _watch_files(self, paths, fun: Callable[[Change, str],
None]) -> None:
"""Watch multiple file paths asynchronously."""
logger.info("SSLCertRefresher monitors files: %s", paths)
async for changes in awatch(*paths):
try:
for change, file_path in changes:
logger.info("File change detected: %s - %s", change.name,
file_path)
fun(change, file_path)
except Exception as e:
logger.error(
"SSLCertRefresher failed taking action on file change. "
"Error: %s", e)
def stop(self) -> None:
"""Stop watching files."""
if self.watch_ssl_cert_task:
self.watch_ssl_cert_task.cancel()
self.watch_ssl_cert_task = None
if self.watch_ssl_ca_task:
self.watch_ssl_ca_task.cancel()
self.watch_ssl_ca_task = None