Rename servers and change port numbers to reduce confusion (#149)

Authored by Zhuohan Li on 2023-06-17 00:13:02 +08:00; committed by GitHub
parent 311490a720
commit eedb46bf03
10 changed files with 41 additions and 37 deletions

View File

@@ -52,7 +52,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
-parser.add_argument("--port", type=int, default=8001)
+parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--max-tokens", type=int, default=128)
parser.add_argument("--n-threads", type=int, default=128)
args = parser.parse_args()

View File

@@ -2,7 +2,7 @@ from cacheflow.entrypoints.llm import LLM
from cacheflow.outputs import RequestOutput, CompletionOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
from cacheflow.server.ray_utils import initialize_cluster
__version__ = "0.1.0"
@@ -12,7 +12,7 @@ __all__ = [
"SamplingParams",
"RequestOutput",
"CompletionOutput",
-"LLMServer",
+"LLMEngine",
"ServerArgs",
"initialize_cluster",
]

View File

@@ -8,7 +8,7 @@ import uvicorn
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMServer
+from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
@@ -18,7 +18,7 @@ app = FastAPI()
@app.post("/generate")
async def generate(request: Request) -> Response:
-""" Stream the results of the generation request.
+"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
@@ -74,12 +74,12 @@ async def generate(request: Request) -> Response:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
-parser.add_argument("--port", type=int, default=8001)
+parser.add_argument("--port", type=int, default=8000)
parser = AsyncServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = AsyncServerArgs.from_cli_args(args)
-server = AsyncLLMServer.from_server_args(server_args)
+server = AsyncLLMEngine.from_server_args(server_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
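
The docstring above describes the JSON body accepted by the /generate route. Below is a minimal sketch of a non-streaming call against the new default port 8000; the extra sampling fields and the {"text": [...]} response shape are assumptions inferred from the example client later in this commit, not guaranteed by this hunk.

import requests

# Hypothetical call to the API server's /generate route on its new default port.
payload = {
    "prompt": "San Francisco is a",
    "n": 1,
    "temperature": 0.8,
    "max_tokens": 64,  # sampling fields are assumed to be forwarded to SamplingParams
}
resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
print(resp.json()["text"])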

View File

@@ -6,7 +6,7 @@ from tqdm import tqdm
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
from cacheflow.utils import Counter
@@ -20,7 +20,7 @@ class LLM:
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
-serving, use the `AsyncLLMServer` class instead.
+serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `ServerArgs`.
Args:
@@ -52,7 +52,7 @@ class LLM:
seed=seed,
**kwargs,
)
-self.llm_server = LLMServer.from_server_args(server_args)
+self.llm_server = LLMEngine.from_server_args(server_args)
self.request_counter = Counter()
def get_tokenizer(
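
Since the docstring distinguishes offline use (LLM) from online serving (AsyncLLMEngine), a short offline sketch follows; the model name and constructor keywords are illustrative assumptions, not taken from this hunk.

from cacheflow import LLM, SamplingParams

# Offline batched inference through the LLM wrapper, which builds an LLMEngine
# internally via LLMEngine.from_server_args (see the hunk above).
llm = LLM(model="facebook/opt-125m")  # illustrative model name
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["Hello, my name is", "The capital of France is"],
                       sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)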

View File

@@ -15,7 +15,7 @@ import uvicorn
from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMServer
+from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
@@ -319,7 +319,7 @@ if __name__ == "__main__":
served_model = args.served_model_name or args.model
server_args = AsyncServerArgs.from_cli_args(args)
-server = AsyncLLMServer.from_server_args(server_args)
+server = AsyncLLMEngine.from_server_args(server_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model)
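
This entrypoint resolves a served_model name, which suggests it is the OpenAI-compatible frontend. If so, a standard openai-python 0.x client pointed at the local server might look like the sketch below; the /v1 base path, the completion route, and the port are assumptions, none of them shown in this hunk.

import openai  # openai-python 0.x style client, used only as an illustration

openai.api_key = "EMPTY"  # assumed: a local server that ignores the key
openai.api_base = "http://localhost:8000/v1"  # assumed base URL

completion = openai.Completion.create(
    model="facebook/opt-125m",    # should match the served model name
    prompt="San Francisco is a",
    max_tokens=64,
)
print(completion.choices[0].text)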

View File

@@ -6,7 +6,7 @@ from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
from cacheflow.server.ray_utils import ray, initialize_cluster
logger = init_logger(__name__)
@@ -14,26 +14,26 @@ logger = init_logger(__name__)
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
-class AsyncLLMServer:
-"""An asynchronous wrapper for LLMServer.
+class AsyncLLMEngine:
+"""An asynchronous wrapper for LLMEngine.
-This class is used to wrap the LLMServer class to make it asynchronous. It
+This class is used to wrap the LLMEngine class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
-requests. The LLMServer is kicked by the generate method when there
+requests. The LLMEngine is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
-from the LLMServer to the caller.
+from the LLMEngine to the caller.
-NOTE: For the comprehensive list of arguments, see `LLMServer`.
+NOTE: For the comprehensive list of arguments, see `LLMEngine`.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
-server_use_ray: Whether to make LLMServer a Ray actor. If so, the
+server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests.
-*args, *kwargs: Arguments for LLMServer.
+*args, *kwargs: Arguments for LLMEngine.
"""
def __init__(self, worker_use_ray: bool, server_use_ray: bool,
log_requests: bool = True, *args, **kwargs) -> None:
@@ -41,11 +41,11 @@ class AsyncLLMServer:
self.server_use_ray = server_use_ray
self.log_requests = log_requests
if not self.server_use_ray:
-server_class = LLMServer
+server_class = LLMEngine
elif self.worker_use_ray:
-server_class = ray.remote(num_cpus=0)(LLMServer).remote
+server_class = ray.remote(num_cpus=0)(LLMEngine).remote
else:
-server_class = ray.remote(num_gpus=1)(LLMServer).remote
+server_class = ray.remote(num_gpus=1)(LLMEngine).remote
self.server = server_class(*args, **kwargs)
# Request id -> request output.
self.request_outputs: Dict[str, RequestOutput] = {}
@@ -85,8 +85,8 @@ class AsyncLLMServer:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
-request into the waiting queue of the LLMServer and streams the outputs
-from the LLMServer to the caller.
+request into the waiting queue of the LLMEngine and streams the outputs
+from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
@@ -97,7 +97,7 @@ class AsyncLLMServer:
use the tokenizer to convert the prompts to token IDs.
Yields:
-The output `RequestOutput` objects from the LLMServer for the
+The output `RequestOutput` objects from the LLMEngine for the
request.
"""
# Preprocess the request.
@@ -200,7 +200,7 @@ class AsyncLLMServer:
self.kicking_request_id = None
@classmethod
-def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
+def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
"""Creates an async LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
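
To make the wrapper's role concrete, here is a minimal sketch of driving AsyncLLMEngine directly rather than through the FastAPI servers; the AsyncServerArgs fields and the exact generate() signature used here are inferred from the hunks above and may differ from the real API.

import asyncio

from cacheflow import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.utils import random_uuid


async def main() -> None:
    # Build the async engine the same way the API servers in this commit do.
    engine = AsyncLLMEngine.from_server_args(
        AsyncServerArgs(model="facebook/opt-125m"))  # illustrative model name

    # generate() queues the request on the wrapped LLMEngine and yields
    # RequestOutput objects as they become available.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
    async for request_output in engine.generate("Hello, my name is",
                                                sampling_params,
                                                request_id=random_uuid()):
        print(request_output.outputs[0].text)


asyncio.run(main())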

View File

@@ -18,7 +18,7 @@ from cacheflow.worker.worker import Worker
logger = init_logger(__name__)
-class LLMServer:
+class LLMEngine:
"""An LLM server that receives requests and generates texts.
This is the main class for the CacheFlow LLM server. It receives requests
@@ -29,7 +29,7 @@ class LLMServer:
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
-`AsyncLLMServer` class wraps this class for online serving.
+`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `ServerArgs` class. For the
comprehensive list of arguments, see `ServerArgs`.
@@ -135,7 +135,7 @@ class LLMServer:
self._run_workers("init_cache_engine", cache_config=self.cache_config)
@classmethod
-def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine":
"""Creates an LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
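
As a rough illustration of the request/step cycle this class implements, a synchronous driver loop might look like the sketch below; the add_request, step, and has_unfinished_requests methods are assumptions about the engine API and are not shown in these hunks.

from cacheflow import LLMEngine, SamplingParams, ServerArgs

# Hypothetical synchronous driver loop around the engine.
engine = LLMEngine.from_server_args(ServerArgs(model="facebook/opt-125m"))
engine.add_request("0", "The future of AI is", SamplingParams(max_tokens=32))

while engine.has_unfinished_requests():
    # One engine iteration: schedule, run the workers, and collect any
    # RequestOutput objects that are ready.
    for request_output in engine.step():
        print(request_output)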

View File

@@ -1,3 +1,5 @@
+"""Example Python client for cacheflow.entrypoints.api_server"""
import argparse
import json
from typing import Iterable, List
@@ -45,7 +47,7 @@ def get_response(response: requests.Response) -> List[str]:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
-parser.add_argument("--port", type=int, default=8001)
+parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--n", type=int, default=4)
parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")

View File

@@ -9,6 +9,7 @@ def http_bot(prompt):
headers = {"User-Agent": "Cacheflow Client"}
pload = {
"prompt": prompt,
+"stream": True,
"max_tokens": 128,
}
response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
@@ -34,8 +35,8 @@ def build_demo():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
-parser.add_argument("--port", type=int, default=8002)
-parser.add_argument("--model-url", type=str, default="http://localhost:8001/generate")
+parser.add_argument("--port", type=int, default=8001)
+parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
args = parser.parse_args()
demo = build_demo()
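
The new "stream": True field turns the /generate call above into a streaming response. A sketch of how such a response could be consumed is shown below; the null-byte delimiter and the {"text": [...]} chunk layout are assumptions rather than something this diff shows.

import json
import requests

response = requests.post("http://localhost:8000/generate",
                         headers={"User-Agent": "Cacheflow Client"},
                         json={"prompt": "San Francisco is a",
                               "stream": True,
                               "max_tokens": 128},
                         stream=True)
# Read delimiter-separated JSON chunks as they arrive (delimiter is assumed).
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"][0])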

View File

@@ -1,12 +1,12 @@
import argparse
-from cacheflow import ServerArgs, LLMServer, SamplingParams
+from cacheflow import ServerArgs, LLMEngine, SamplingParams
def main(args: argparse.Namespace):
# Parse the CLI argument and initialize the server.
server_args = ServerArgs.from_cli_args(args)
-server = LLMServer.from_server_args(server_args)
+server = LLMEngine.from_server_args(server_args)
# Test the following prompts.
test_prompts = [
@@ -38,7 +38,8 @@ def main(args: argparse.Namespace):
if __name__ == '__main__':
-parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
+parser = argparse.ArgumentParser(
+description='Demo on using the LLMEngine class synchronously')
parser = ServerArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)