mirror of https://github.com/vllm-project/vllm
Support SSL Key Rotation in HTTP Server (#13495)
This commit is contained in:
parent
2382ad29d1
commit
8db1b9d0a1
|
@ -20,7 +20,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
|
|||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer >= 0.10.9, < 0.11
|
||||
outlines == 0.1.11
|
||||
lark == 1.2.2
|
||||
lark == 1.2.2
|
||||
xgrammar == 0.1.11; platform_machine == "x86_64"
|
||||
typing_extensions >= 4.10
|
||||
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
|
||||
|
@ -37,3 +37,4 @@ einops # Required for Qwen2-VL.
|
|||
compressed-tensors == 0.9.2 # required for compressed-tensors
|
||||
depyf==0.18.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
watchfiles # required for http server to monitor the updates of TLS files
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
import asyncio
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from ssl import SSLContext
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.ssl import SSLCertRefresher
|
||||
|
||||
|
||||
class MockSSLContext(SSLContext):
|
||||
|
||||
def __init__(self):
|
||||
self.load_cert_chain_count = 0
|
||||
self.load_ca_count = 0
|
||||
|
||||
def load_cert_chain(
|
||||
self,
|
||||
certfile,
|
||||
keyfile=None,
|
||||
password=None,
|
||||
):
|
||||
self.load_cert_chain_count += 1
|
||||
|
||||
def load_verify_locations(
|
||||
self,
|
||||
cafile=None,
|
||||
capath=None,
|
||||
cadata=None,
|
||||
):
|
||||
self.load_ca_count += 1
|
||||
|
||||
|
||||
def create_file() -> str:
|
||||
with tempfile.NamedTemporaryFile(dir='/tmp', delete=False) as f:
|
||||
return f.name
|
||||
|
||||
|
||||
def touch_file(path: str) -> None:
|
||||
Path(path).touch()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ssl_refresher():
|
||||
ssl_context = MockSSLContext()
|
||||
key_path = create_file()
|
||||
cert_path = create_file()
|
||||
ca_path = create_file()
|
||||
ssl_refresher = SSLCertRefresher(ssl_context, key_path, cert_path, ca_path)
|
||||
await asyncio.sleep(1)
|
||||
assert ssl_context.load_cert_chain_count == 0
|
||||
assert ssl_context.load_ca_count == 0
|
||||
|
||||
touch_file(key_path)
|
||||
await asyncio.sleep(1)
|
||||
assert ssl_context.load_cert_chain_count == 1
|
||||
assert ssl_context.load_ca_count == 0
|
||||
|
||||
touch_file(cert_path)
|
||||
touch_file(ca_path)
|
||||
await asyncio.sleep(1)
|
||||
assert ssl_context.load_cert_chain_count == 2
|
||||
assert ssl_context.load_ca_count == 1
|
||||
|
||||
ssl_refresher.stop()
|
||||
|
||||
touch_file(cert_path)
|
||||
touch_file(ca_path)
|
||||
await asyncio.sleep(1)
|
||||
assert ssl_context.load_cert_chain_count == 2
|
||||
assert ssl_context.load_ca_count == 1
|
|
@ -128,6 +128,7 @@ async def run_server(args: Namespace,
|
|||
shutdown_task = await serve_http(
|
||||
app,
|
||||
sock=None,
|
||||
enable_ssl_refresh=args.enable_ssl_refresh,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.log_level,
|
||||
|
@ -152,6 +153,11 @@ if __name__ == "__main__":
|
|||
type=str,
|
||||
default=None,
|
||||
help="The CA certificates file")
|
||||
parser.add_argument(
|
||||
"--enable-ssl-refresh",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Refresh SSL Context when SSL certificate files change")
|
||||
parser.add_argument(
|
||||
"--ssl-cert-reqs",
|
||||
type=int,
|
||||
|
|
|
@ -12,13 +12,16 @@ from fastapi import FastAPI, Request, Response
|
|||
from vllm import envs
|
||||
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
||||
from vllm.engine.multiprocessing import MQEngineDeadError
|
||||
from vllm.entrypoints.ssl import SSLCertRefresher
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_process_using_port
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
async def serve_http(app: FastAPI, sock: Optional[socket.socket],
|
||||
async def serve_http(app: FastAPI,
|
||||
sock: Optional[socket.socket],
|
||||
enable_ssl_refresh: bool = False,
|
||||
**uvicorn_kwargs: Any):
|
||||
logger.info("Available routes are:")
|
||||
for route in app.routes:
|
||||
|
@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
|
|||
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
|
||||
|
||||
config = uvicorn.Config(app, **uvicorn_kwargs)
|
||||
config.load()
|
||||
server = uvicorn.Server(config)
|
||||
_add_shutdown_handlers(app, server)
|
||||
|
||||
|
@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
|
|||
server_task = loop.create_task(
|
||||
server.serve(sockets=[sock] if sock else None))
|
||||
|
||||
ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
|
||||
ssl_context=config.ssl,
|
||||
key_path=config.ssl_keyfile,
|
||||
cert_path=config.ssl_certfile,
|
||||
ca_path=config.ssl_ca_certs)
|
||||
|
||||
def signal_handler() -> None:
|
||||
# prevents the uvicorn signal handler to exit early
|
||||
server_task.cancel()
|
||||
if ssl_cert_refresher:
|
||||
ssl_cert_refresher.stop()
|
||||
|
||||
async def dummy_shutdown() -> None:
|
||||
pass
|
||||
|
|
|
@ -960,6 +960,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
|||
shutdown_task = await serve_http(
|
||||
app,
|
||||
sock=sock,
|
||||
enable_ssl_refresh=args.enable_ssl_refresh,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.uvicorn_log_level,
|
||||
|
|
|
@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||
type=nullable_str,
|
||||
default=None,
|
||||
help="The CA certificates file.")
|
||||
parser.add_argument(
|
||||
"--enable-ssl-refresh",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Refresh SSL Context when SSL certificate files change")
|
||||
parser.add_argument(
|
||||
"--ssl-cert-reqs",
|
||||
type=int,
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
from ssl import SSLContext
|
||||
from typing import Callable, Optional
|
||||
|
||||
from watchfiles import Change, awatch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SSLCertRefresher:
|
||||
"""A class that monitors SSL certificate files and
|
||||
reloads them when they change.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ssl_context: SSLContext,
|
||||
key_path: Optional[str] = None,
|
||||
cert_path: Optional[str] = None,
|
||||
ca_path: Optional[str] = None) -> None:
|
||||
self.ssl = ssl_context
|
||||
self.key_path = key_path
|
||||
self.cert_path = cert_path
|
||||
self.ca_path = ca_path
|
||||
|
||||
# Setup certification chain watcher
|
||||
def update_ssl_cert_chain(change: Change, file_path: str) -> None:
|
||||
logger.info("Reloading SSL certificate chain")
|
||||
assert self.key_path and self.cert_path
|
||||
self.ssl.load_cert_chain(self.cert_path, self.key_path)
|
||||
|
||||
self.watch_ssl_cert_task = None
|
||||
if self.key_path and self.cert_path:
|
||||
self.watch_ssl_cert_task = asyncio.create_task(
|
||||
self._watch_files([self.key_path, self.cert_path],
|
||||
update_ssl_cert_chain))
|
||||
|
||||
# Setup CA files watcher
|
||||
def update_ssl_ca(change: Change, file_path: str) -> None:
|
||||
logger.info("Reloading SSL CA certificates")
|
||||
assert self.ca_path
|
||||
self.ssl.load_verify_locations(self.ca_path)
|
||||
|
||||
self.watch_ssl_ca_task = None
|
||||
if self.ca_path:
|
||||
self.watch_ssl_ca_task = asyncio.create_task(
|
||||
self._watch_files([self.ca_path], update_ssl_ca))
|
||||
|
||||
async def _watch_files(self, paths, fun: Callable[[Change, str],
|
||||
None]) -> None:
|
||||
"""Watch multiple file paths asynchronously."""
|
||||
logger.info("SSLCertRefresher monitors files: %s", paths)
|
||||
async for changes in awatch(*paths):
|
||||
try:
|
||||
for change, file_path in changes:
|
||||
logger.info("File change detected: %s - %s", change.name,
|
||||
file_path)
|
||||
fun(change, file_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"SSLCertRefresher failed taking action on file change. "
|
||||
"Error: %s", e)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop watching files."""
|
||||
if self.watch_ssl_cert_task:
|
||||
self.watch_ssl_cert_task.cancel()
|
||||
self.watch_ssl_cert_task = None
|
||||
if self.watch_ssl_ca_task:
|
||||
self.watch_ssl_ca_task.cancel()
|
||||
self.watch_ssl_ca_task = None
|
Loading…
Reference in New Issue