mirror of https://github.com/vllm-project/vllm

Rename servers and change port numbers to reduce confusion (#149)

commit eedb46bf03 (parent 311490a720)

@@ -52,7 +52,7 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--max-tokens", type=int, default=128)
     parser.add_argument("--n-threads", type=int, default=128)
     args = parser.parse_args()

@@ -2,7 +2,7 @@ from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput, CompletionOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
 from cacheflow.server.ray_utils import initialize_cluster
 
 __version__ = "0.1.0"
@@ -12,7 +12,7 @@ __all__ = [
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",
-    "LLMServer",
+    "LLMEngine",
     "ServerArgs",
     "initialize_cluster",
 ]

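With this hunk, the package's public name changes from LLMServer to LLMEngine. A minimal sketch of the import-level impact on downstream code; only the names themselves come from the diff:

from cacheflow import LLM, LLMEngine, ServerArgs, SamplingParams  # new public name
# from cacheflow import LLMServer  # old name, no longer exported after this commit
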
@@ -8,7 +8,7 @@ import uvicorn
 
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMServer
+from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.utils import random_uuid
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -18,7 +18,7 @@ app = FastAPI()
 
 @app.post("/generate")
 async def generate(request: Request) -> Response:
-    """Stream the results of the generation request.
+    """Generate completion for the request.
 
     The request should be a JSON object with the following fields:
     - prompt: the prompt to use for the generation.
@@ -74,12 +74,12 @@ async def generate(request: Request) -> Response:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMServer.from_server_args(server_args)
+    server = AsyncLLMEngine.from_server_args(server_args)
 
     uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

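The /generate route documented above takes a JSON body with at least a prompt field, and after this change the server listens on port 8000 by default. A hedged client sketch; the stream and max_tokens fields mirror what the gradio client later in this diff sends, and the response schema is treated as opaque JSON because it is not shown here:

import requests

response = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "San Francisco is a",
        "stream": False,
        "max_tokens": 64,  # assumed to be forwarded to SamplingParams
    },
)
print(response.json())
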
@@ -6,7 +6,7 @@ from tqdm import tqdm
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter
 
 
@@ -20,7 +20,7 @@ class LLM:
     mechanism and efficient memory management.
 
     NOTE: This class is intended to be used for offline inference. For online
-    serving, use the `AsyncLLMServer` class instead.
+    serving, use the `AsyncLLMEngine` class instead.
     NOTE: For the comprehensive list of arguments, see `ServerArgs`.
 
     Args:
@@ -52,7 +52,7 @@ class LLM:
             seed=seed,
             **kwargs,
         )
-        self.llm_server = LLMServer.from_server_args(server_args)
+        self.llm_server = LLMEngine.from_server_args(server_args)
         self.request_counter = Counter()
 
     def get_tokenizer(
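
The NOTE above positions LLM as the offline entry point built on top of LLMEngine. A minimal offline-inference sketch under that reading; the model name, sampling arguments, and the RequestOutput/CompletionOutput attribute names are assumptions, not part of this diff:

from cacheflow import LLM, SamplingParams

# LLM builds ServerArgs from its keyword arguments and calls
# LLMEngine.from_server_args() internally, as the hunk above shows.
llm = LLM(model="facebook/opt-125m")  # example model id, assumed supported
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

outputs = llm.generate(
    ["San Francisco is a", "The capital of France is"],
    sampling_params,
)
for output in outputs:
    # RequestOutput is assumed to carry the prompt and a list of CompletionOutputs.
    print(output.prompt, output.outputs[0].text)
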
@@ -15,7 +15,7 @@ import uvicorn
 
 from cacheflow.outputs import RequestOutput
 from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMServer
+from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
 from cacheflow.sampling_params import SamplingParams
@@ -319,7 +319,7 @@ if __name__ == "__main__":
     served_model = args.served_model_name or args.model
 
     server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMServer.from_server_args(server_args)
+    server = AsyncLLMEngine.from_server_args(server_args)
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)

@@ -6,7 +6,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.llm_server import LLMServer
+from cacheflow.server.llm_server import LLMEngine
 from cacheflow.server.ray_utils import ray, initialize_cluster
 
 logger = init_logger(__name__)
@@ -14,26 +14,26 @@ logger = init_logger(__name__)
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 
 
-class AsyncLLMServer:
-    """An asynchronous wrapper for LLMServer.
+class AsyncLLMEngine:
+    """An asynchronous wrapper for LLMEngine.
 
-    This class is used to wrap the LLMServer class to make it asynchronous. It
+    This class is used to wrap the LLMEngine class to make it asynchronous. It
     uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMServer is kicked by the generate method when there
+    requests. The LLMEngine is kicked by the generate method when there
     are requests in the waiting queue. The generate method yields the outputs
-    from the LLMServer to the caller.
+    from the LLMEngine to the caller.
 
-    NOTE: For the comprehensive list of arguments, see `LLMServer`.
+    NOTE: For the comprehensive list of arguments, see `LLMEngine`.
 
     Args:
         worker_use_ray: Whether to use Ray for model workers. Required for
             distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
-        server_use_ray: Whether to make LLMServer a Ray actor. If so, the
+        server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
             async frontend will be executed in a separate process as the
            model workers.
         log_requests: Whether to log the requests.
-        *args, *kwargs: Arguments for LLMServer.
+        *args, *kwargs: Arguments for LLMEngine.
     """
     def __init__(self, worker_use_ray: bool, server_use_ray: bool,
                  log_requests: bool = True, *args, **kwargs) -> None:
@@ -41,11 +41,11 @@ class AsyncLLMServer:
         self.server_use_ray = server_use_ray
         self.log_requests = log_requests
         if not self.server_use_ray:
-            server_class = LLMServer
+            server_class = LLMEngine
         elif self.worker_use_ray:
-            server_class = ray.remote(num_cpus=0)(LLMServer).remote
+            server_class = ray.remote(num_cpus=0)(LLMEngine).remote
         else:
-            server_class = ray.remote(num_gpus=1)(LLMServer).remote
+            server_class = ray.remote(num_gpus=1)(LLMEngine).remote
         self.server = server_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
@@ -85,8 +85,8 @@ class AsyncLLMServer:
         """Generate outputs for a request.
 
         Generate outputs for a request. This method is a coroutine. It adds the
-        request into the waiting queue of the LLMServer and streams the outputs
-        from the LLMServer to the caller.
+        request into the waiting queue of the LLMEngine and streams the outputs
+        from the LLMEngine to the caller.
 
         Args:
             prompt: The prompt string. Can be None if prompt_token_ids is
@@ -97,7 +97,7 @@ class AsyncLLMServer:
             use the tokenizer to convert the prompts to token IDs.
 
         Yields:
-            The output `RequestOutput` objects from the LLMServer for the
+            The output `RequestOutput` objects from the LLMEngine for the
             request.
         """
         # Preprocess the request.
@@ -200,7 +200,7 @@ class AsyncLLMServer:
         self.kicking_request_id = None
 
     @classmethod
-    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
+    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
         """Creates an async LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()

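The docstrings above describe AsyncLLMEngine as an asyncio wrapper whose generate() method yields RequestOutput objects for a given request id. A hedged usage sketch; the AsyncServerArgs field name, the model id, and the exact generate() signature are assumptions inferred from the docstring rather than guaranteed by this diff:

import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.utils import random_uuid


async def demo() -> None:
    server_args = AsyncServerArgs(model="facebook/opt-125m")  # assumed field name
    engine = AsyncLLMEngine.from_server_args(server_args)
    sampling_params = SamplingParams(temperature=0.8, max_tokens=64)

    # generate() is assumed to be an async generator of RequestOutput objects,
    # keyed by a unique request id, as the Yields section above describes.
    async for request_output in engine.generate(
            "San Francisco is a", sampling_params, request_id=random_uuid()):
        print(request_output)


asyncio.run(demo())
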
@@ -18,7 +18,7 @@ from cacheflow.worker.worker import Worker
 logger = init_logger(__name__)
 
 
-class LLMServer:
+class LLMEngine:
     """An LLM server that receives requests and generates texts.
 
     This is the main class for the CacheFlow LLM server. It receives requests
@@ -29,7 +29,7 @@ class LLMServer:
     serving throughput.
 
     The `LLM` class wraps this class for offline batched inference and the
-    `AsyncLLMServer` class wraps this class for online serving.
+    `AsyncLLMEngine` class wraps this class for online serving.
 
     NOTE: The config arguments are derived from the `ServerArgs` class. For the
     comprehensive list of arguments, see `ServerArgs`.
@@ -135,7 +135,7 @@ class LLMServer:
         self._run_workers("init_cache_engine", cache_config=self.cache_config)
 
     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine":
         """Creates an LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()

@@ -1,3 +1,5 @@
+"""Example Python client for cacheflow.entrypoints.api_server"""
+
 import argparse
 import json
 from typing import Iterable, List
@@ -45,7 +47,7 @@ def get_response(response: requests.Response) -> List[str]:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--n", type=int, default=4)
     parser.add_argument("--prompt", type=str, default="San Francisco is a")
     parser.add_argument("--stream", action="store_true")

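The example client above now defaults to port 8000 and exposes a --stream flag. A sketch of consuming the streamed variant; the NUL-delimited JSON framing and the "text" field are assumptions about the server's streaming format, not something this diff shows:

import json

import requests

response = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "San Francisco is a", "n": 4, "stream": True},
    stream=True,
)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"])
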
@@ -9,6 +9,7 @@ def http_bot(prompt):
     headers = {"User-Agent": "Cacheflow Client"}
     pload = {
         "prompt": prompt,
         "stream": True,
+        "max_tokens": 128,
     }
     response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
@@ -34,8 +35,8 @@ def build_demo():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--port", type=int, default=8002)
-    parser.add_argument("--model-url", type=str, default="http://localhost:8001/generate")
+    parser.add_argument("--port", type=int, default=8001)
+    parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
     args = parser.parse_args()
 
     demo = build_demo()

@@ -1,12 +1,12 @@
 import argparse
 
-from cacheflow import ServerArgs, LLMServer, SamplingParams
+from cacheflow import ServerArgs, LLMEngine, SamplingParams
 
 
 def main(args: argparse.Namespace):
     # Parse the CLI argument and initialize the server.
     server_args = ServerArgs.from_cli_args(args)
-    server = LLMServer.from_server_args(server_args)
+    server = LLMEngine.from_server_args(server_args)
 
     # Test the following prompts.
     test_prompts = [
@@ -38,7 +38,8 @@ def main(args: argparse.Namespace):
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
+    parser = argparse.ArgumentParser(
+        description='Demo on using the LLMEngine class synchronously')
     parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)
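
The demo above only shows the setup half of main(); the loop that drives the test prompts is elided from this diff. A hedged sketch of such a loop, assuming LLMEngine keeps the add_request()/step()/has_unfinished_requests() interface that the old LLMServer exposed (those method names are assumptions here):

from cacheflow import ServerArgs, LLMEngine, SamplingParams

server_args = ServerArgs(model="facebook/opt-125m")  # example model id
engine = LLMEngine.from_server_args(server_args)

test_prompts = [
    ("San Francisco is a", SamplingParams(temperature=0.8)),
    ("The capital of France is", SamplingParams(temperature=0.0)),
]

request_id = 0
while test_prompts or engine.has_unfinished_requests():
    # Feed one request per iteration to exercise iteration-level scheduling.
    if test_prompts:
        prompt, sampling_params = test_prompts.pop(0)
        engine.add_request(str(request_id), prompt, sampling_params)
        request_id += 1

    # step() is assumed to run one scheduling/decoding iteration and return
    # the RequestOutput objects produced so far.
    for request_output in engine.step():
        print(request_output)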