Rename servers to engines (#152)

Zhuohan Li 2023-06-17 17:25:21 +08:00 committed by GitHub
parent bab8f3dd0d
commit e5464ee484
15 changed files with 165 additions and 174 deletions

View File

@ -14,7 +14,7 @@ def main(args: argparse.Namespace):
# Process all the requests in a single batch if possible.
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the server will automatically process the request in multiple batches.
# the engine will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
tensor_parallel_size=args.tensor_parallel_size,

View File

@ -2,7 +2,7 @@
On the server side, run one of the following commands:
(CacheFlow backend)
python -m cacheflow.entrypoints.simple_fastapi_frontend \
python -m cacheflow.entrypoints.api_server \
--disable-log-requests --model <your_model>
(TGI backend)

View File

@ -84,7 +84,7 @@ def run_cacheflow(
seed=seed,
)
# Add the requests to the server.
# Add the requests to the engine.
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
n=n,
@ -103,7 +103,7 @@ def run_cacheflow(
start = time.time()
# FIXME(woosuk): Do not use internal method.
llm._run_server(use_tqdm=True)
llm._run_engine(use_tqdm=True)
end = time.time()
return end - start

View File

@ -1,9 +1,9 @@
from cacheflow.engine.arg_utils import EngineArgs
from cacheflow.engine.llm_engine import LLMEngine
from cacheflow.engine.ray_utils import initialize_cluster
from cacheflow.entrypoints.llm import LLM
from cacheflow.outputs import RequestOutput, CompletionOutput
from cacheflow.outputs import CompletionOutput, RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMEngine
from cacheflow.server.ray_utils import initialize_cluster
__version__ = "0.1.0"
@ -13,6 +13,6 @@ __all__ = [
"RequestOutput",
"CompletionOutput",
"LLMEngine",
"ServerArgs",
"EngineArgs",
"initialize_cluster",
]
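
For reference, a minimal sketch of what importing the renamed public API looks like after this change (EngineArgs replaces ServerArgs; LLMEngine keeps its name but moves from cacheflow.server.llm_server to cacheflow.engine.llm_engine):

from cacheflow import EngineArgs, LLMEngine

# EngineArgs is the renamed ServerArgs dataclass; only `model` is required.
engine_args = EngineArgs(model="facebook/opt-125m")
engine = LLMEngine.from_engine_args(engine_args)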

View File

@ -216,7 +216,7 @@ class Scheduler:
if not self.log_stats:
return scheduler_outputs, prompt_group_ids
# TODO(woosuk): Move the below code to server.
# TODO(woosuk): Move the below code to the engine.
now = time.time()
if num_batched_tokens > 0:
self.num_input_tokens.append((now, num_batched_tokens))

View File

@ -8,8 +8,8 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
@dataclass
class ServerArgs:
"""Arguments for CacheFlow servers."""
class EngineArgs:
"""Arguments for CacheFlow engine."""
model: str
download_dir: Optional[str] = None
use_np_weights: bool = False
@ -33,12 +33,12 @@ class ServerArgs:
def add_cli_args(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
"""Shared CLI arguments for CacheFlow servers."""
"""Shared CLI arguments for CacheFlow engine."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument('--download-dir', type=str,
default=ServerArgs.download_dir,
default=EngineArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
@ -49,7 +49,7 @@ class ServerArgs:
parser.add_argument('--use-dummy-weights', action='store_true',
help='use dummy values for model weights')
# TODO(woosuk): Support FP32.
parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
choices=['auto', 'half', 'bfloat16', 'float'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
@ -60,46 +60,46 @@ class ServerArgs:
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
default=ServerArgs.pipeline_parallel_size,
default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int,
default=ServerArgs.tensor_parallel_size,
default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int,
default=ServerArgs.block_size,
default=EngineArgs.block_size,
choices=[8, 16, 32],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=ServerArgs.seed,
parser.add_argument('--seed', type=int, default=EngineArgs.seed,
help='random seed')
parser.add_argument('--swap-space', type=int,
default=ServerArgs.swap_space,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float,
default=ServerArgs.gpu_memory_utilization,
default=EngineArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for '
'the model executor')
parser.add_argument('--max-num-batched-tokens', type=int,
default=ServerArgs.max_num_batched_tokens,
default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--max-num-seqs', type=int,
default=ServerArgs.max_num_seqs,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true',
help='disable logging statistics')
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return server_args
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args
def create_server_configs(
def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs.
@ -117,19 +117,19 @@ class ServerArgs:
@dataclass
class AsyncServerArgs(ServerArgs):
"""Arguments for asynchronous CacheFlow servers."""
server_use_ray: bool = False
class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous CacheFlow engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
parser = ServerArgs.add_cli_args(parser)
parser.add_argument('--server-use-ray', action='store_true',
help='use Ray to start the LLM server in a '
'separate process as the web server process.')
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray', action='store_true',
help='use Ray to start the LLM engine in a '
'separate process from the server process.')
parser.add_argument('--disable-log-requests', action='store_true',
help='disable logging requests')
return parser
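
create_engine_configs (the renamed create_server_configs) still returns the four config objects in a fixed order. A minimal sketch of consuming it directly, assuming the model name resolves via the HuggingFace hub:

from cacheflow import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m", seed=0)
# Returns (ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig),
# matching the Tuple annotation on create_engine_configs above.
model_config, cache_config, parallel_config, scheduler_config = (
    engine_args.create_engine_configs())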

View File

@ -2,12 +2,12 @@ import asyncio
import time
from typing import Dict, List, Optional
from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.llm_engine import LLMEngine
from cacheflow.engine.ray_utils import initialize_cluster, ray
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.llm_server import LLMEngine
from cacheflow.server.ray_utils import ray, initialize_cluster
logger = init_logger(__name__)
@ -29,44 +29,44 @@ class AsyncLLMEngine:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests.
*args, *kwargs: Arguments for LLMEngine.
"""
def __init__(self, worker_use_ray: bool, server_use_ray: bool,
def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
log_requests: bool = True, *args, **kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.server_use_ray = server_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
if not self.server_use_ray:
server_class = LLMEngine
if not self.engine_use_ray:
engine_class = LLMEngine
elif self.worker_use_ray:
server_class = ray.remote(num_cpus=0)(LLMEngine).remote
engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
else:
server_class = ray.remote(num_gpus=1)(LLMEngine).remote
self.server = server_class(*args, **kwargs)
engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
self.engine = engine_class(*args, **kwargs)
# Request id -> request output.
self.request_outputs: Dict[str, RequestOutput] = {}
# Request id -> event to notify that there is new output.
self.request_events: Dict[str, asyncio.Event] = {}
self.is_server_running = False
self.is_engine_running = False
self.kicking_request_id: Optional[str] = None
async def server_step(self, kicking_request_id: Optional[str] = None):
"""Kick the server to process the waiting requests."""
self.is_server_running = True
async def engine_step(self, kicking_request_id: Optional[str] = None):
"""Kick the engine to process the waiting requests."""
self.is_engine_running = True
self.kicking_request_id = kicking_request_id
if self.server_use_ray:
request_outputs = await self.server.step.remote()
if self.engine_use_ray:
request_outputs = await self.engine.step.remote()
else:
# Yield to the event loop to allow other coroutines to run
# while is_server_running is True. This let the server to add new
# while is_engine_running is True. This lets the engine add new
# requests into the queue.
await asyncio.sleep(0)
request_outputs = self.server.step()
self.is_server_running = False
request_outputs = self.engine.step()
self.is_engine_running = False
self.kicking_request_id = None
# Notify the waiting coroutines that there are new outputs ready.
@ -104,7 +104,7 @@ class AsyncLLMEngine:
arrival_time = time.time()
# Create an event to notify us that there is new output from the
# cacheflow server.
# cacheflow engine.
request_event = asyncio.Event()
self.request_events[request_id] = request_event
@ -114,31 +114,31 @@ class AsyncLLMEngine:
f"sampling params: {sampling_params}, "
f"prompt token ids: {prompt_token_ids}.")
# Add the request into the cacheflow server's waiting queue.
if self.server_use_ray:
await self.server.add_request.remote(
# Add the request into the cacheflow engine's waiting queue.
if self.engine_use_ray:
await self.engine.add_request.remote(
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
else:
self.server.add_request(
self.engine.add_request(
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
# The cacheflow server does not have a background loop that keeps
# The cacheflow engine does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests.
# the engine to process the requests.
while True:
if request_id not in self.request_events:
# The request has been aborted.
return
# Kick the server if the server is not running.
if not self.is_server_running:
await self.server_step(request_id)
# Kick the engine if the engine is not running.
if not self.is_engine_running:
await self.engine_step(request_id)
# Wait for new output. The group_event will be set in server_step
# Wait for new output. The group_event will be set in engine_step
# when there is new output available for the sequence group.
# Added a timeout to prevent deadlock.
try:
@ -160,11 +160,11 @@ class AsyncLLMEngine:
del self.request_outputs[request_id]
del self.request_events[request_id]
# Kick the server if the server is not running. This is to
# prevent that there are still requests in server's waiting
# Kick the engine if the engine is not running. This is to
# prevent that there are still requests in engine's waiting
# queue to be executed.
if not self.is_server_running:
await self.server_step()
if not self.is_engine_running:
await self.engine_step()
break
async def abort(self, request_id: str) -> None:
@ -183,36 +183,36 @@ class AsyncLLMEngine:
if self.log_requests:
logger.info(f"Aborted request {request_id}.")
if self.server_use_ray:
await self.server.abort_request.remote(request_id)
if self.engine_use_ray:
await self.engine.abort_request.remote(request_id)
else:
self.server.abort_request(request_id)
self.engine.abort_request(request_id)
if request_id in self.request_events:
del self.request_events[request_id]
if request_id in self.request_outputs:
del self.request_outputs[request_id]
# To prevent deadlock when a request is aborted while the server is
# To prevent deadlock when a request is aborted while the engine is
# running.
if self.kicking_request_id == request_id:
self.is_server_running = False
self.is_engine_running = False
self.kicking_request_id = None
@classmethod
def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
"""Creates an async LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(
parallel_config, server_args.server_use_ray)
# Create the LLM server.
server = cls(server_args.worker_use_ray,
server_args.server_use_ray,
not server_args.disable_log_requests,
*server_configs,
parallel_config, engine_args.engine_use_ray)
# Create the async LLM engine.
engine = cls(engine_args.worker_use_ray,
engine_args.engine_use_ray,
not engine_args.disable_log_requests,
*engine_configs,
distributed_init_method, devices,
log_stats=not server_args.disable_log_stats)
return server
log_stats=not engine_args.disable_log_stats)
return engine
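
A hedged sketch of driving the renamed AsyncLLMEngine from an asyncio coroutine; the prompt and sampling values are illustrative, and the call mirrors how the API server further down uses engine.generate:

import asyncio

from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.async_llm_engine import AsyncLLMEngine
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid

async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # generate() is an async generator: it keeps kicking engine_step() and
    # yields a RequestOutput whenever new output is available.
    async for request_output in engine.generate(
            "Hello, my name is", SamplingParams(temperature=0.8),
            random_uuid()):
        if request_output.finished():
            print(request_output)

asyncio.run(main())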

View File

@ -4,13 +4,13 @@ from typing import Any, List, Optional
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.core.scheduler import Scheduler
from cacheflow.engine.arg_utils import EngineArgs
from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.engine.tokenizer_utils import (detokenize_incrementally,
get_tokenizer)
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.server.tokenizer_utils import (get_tokenizer,
detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
from cacheflow.utils import Counter
from cacheflow.worker.worker import Worker
@ -19,9 +19,9 @@ logger = init_logger(__name__)
class LLMEngine:
"""An LLM server that receives requests and generates texts.
"""An LLM engine that receives requests and generates texts.
This is the main class for the CacheFlow LLM server. It receives requests
This is the main class for the CacheFlow LLM engine. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
@ -31,8 +31,8 @@ class LLMEngine:
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `ServerArgs` class. For the
comprehensive list of arguments, see `ServerArgs`.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
Args:
model_config: The configuration related to the LLM model.
@ -58,7 +58,7 @@ class LLMEngine:
log_stats: bool,
) -> None:
logger.info(
"Initializing an LLM server with config: "
"Initializing an LLM engine with config: "
f"model={model_config.model!r}, "
f"dtype={model_config.dtype}, "
f"use_dummy_weights={model_config.use_dummy_weights}, "
@ -135,17 +135,17 @@ class LLMEngine:
self._run_workers("init_cache_engine", cache_config=self.cache_config)
@classmethod
def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine":
"""Creates an LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
# Create the LLM server.
server = cls(*server_configs, distributed_init_method, devices,
log_stats=not server_args.disable_log_stats)
return server
# Create the LLM engine.
engine = cls(*engine_configs, distributed_init_method, devices,
log_stats=not engine_args.disable_log_stats)
return engine
def add_request(
self,
@ -155,10 +155,10 @@ class LLMEngine:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
) -> None:
"""Add a request to the server's request pool.
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
scheduler as `server.step()` is called. The exact scheduling policy is
scheduler as `engine.step()` is called. The exact scheduling policy is
determined by the scheduler.
Args:
@ -211,7 +211,7 @@ class LLMEngine:
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration for the server. It first
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes

View File

@ -13,15 +13,15 @@ DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), devi
def initialize_cluster(
parallel_config: ParallelConfig,
server_use_ray: bool = False,
ray_server_address: Optional[str] = None,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]:
"""Initialize the distributed cluster probably with Ray.
Args:
parallel_config: The configurations for parallel execution.
server_use_ray: Whether to use Ray for async server.
ray_server_address: The address of the Ray cluster. If None, uses
engine_use_ray: Whether to use Ray for async engine.
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
Returns:
@ -31,13 +31,13 @@ def initialize_cluster(
each worker in each pipeline stage. Each device ID is a tuple of
(rank, node resource, device id).
"""
if parallel_config.worker_use_ray or server_use_ray:
if parallel_config.worker_use_ray or engine_use_ray:
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=ray_server_address)
ray.init(address=ray_address)
if not parallel_config.worker_use_ray:
# Initialize cluster locally.

View File

@ -6,9 +6,9 @@ from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import Response, StreamingResponse
import uvicorn
from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.async_llm_engine import AsyncLLMEngine
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
@ -30,7 +30,7 @@ async def generate(request: Request) -> Response:
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
results_generator = server.generate(prompt, sampling_params, request_id)
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
@ -44,7 +44,7 @@ async def generate(request: Request) -> Response:
yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None:
await server.abort(request_id)
await engine.abort(request_id)
if stream:
background_tasks = BackgroundTasks()
@ -57,7 +57,7 @@ async def generate(request: Request) -> Response:
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await server.abort(request_id)
await engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
@ -75,11 +75,11 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser = AsyncServerArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = AsyncServerArgs.from_cli_args(args)
server = AsyncLLMEngine.from_server_args(server_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
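
With the server launched via python -m cacheflow.entrypoints.api_server, requests go to the generate handler shown in this file. A minimal client sketch, assuming the handler is mounted at /generate on the default host/port and that the JSON body carries a "prompt" field plus SamplingParams keyword arguments (as the request_dict handling suggests):

import requests

# "stream" is popped by the handler; the remaining fields are forwarded to
# SamplingParams, so sampling options are passed as top-level JSON keys.
response = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "San Francisco is a", "n": 1, "temperature": 0.0},
)
# The exact response schema is not shown in this diff, so just print the body.
print(response.text)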

View File

@ -1,12 +1,12 @@
from typing import List, Optional, Union
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from cacheflow.engine.arg_utils import EngineArgs
from cacheflow.engine.llm_engine import LLMEngine
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMEngine
from cacheflow.utils import Counter
@ -21,7 +21,7 @@ class LLM:
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `ServerArgs`.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: The name or path of a HuggingFace Transformers model.
@ -45,20 +45,20 @@ class LLM:
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
server_args = ServerArgs(
engine_args = EngineArgs(
model=model,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
seed=seed,
**kwargs,
)
self.llm_server = LLMEngine.from_server_args(server_args)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
self.request_counter = Counter()
def get_tokenizer(
self,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_server.tokenizer
return self.llm_engine.tokenizer
def generate(
self,
@ -99,7 +99,7 @@ class LLM:
# Use default sampling params.
sampling_params = SamplingParams()
# Add requests to the server.
# Add requests to the engine.
if prompts is not None:
num_requests = len(prompts)
else:
@ -111,7 +111,7 @@ class LLM:
else:
token_ids = prompt_token_ids[i]
self._add_request(prompt, sampling_params, token_ids)
return self._run_server(use_tqdm)
return self._run_engine(use_tqdm)
def _add_request(
self,
@ -120,18 +120,18 @@ class LLM:
prompt_token_ids: Optional[List[int]],
) -> None:
request_id = str(next(self.request_counter))
self.llm_server.add_request(request_id, prompt, sampling_params,
self.llm_engine.add_request(request_id, prompt, sampling_params,
prompt_token_ids)
def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_server.get_num_unfinished_requests()
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts")
# Run the server.
# Run the engine.
outputs: List[RequestOutput] = []
while self.llm_server.has_unfinished_requests():
step_outputs = self.llm_server.step()
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished():
outputs.append(output)
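
For completeness, a minimal offline-inference sketch through the renamed wrapper; the prompts and sampling values are illustrative:

from cacheflow import LLM, SamplingParams

# LLM builds an LLMEngine internally (self.llm_engine); generate() adds the
# requests and then drives the engine to completion via _run_engine().
llm = LLM(model="facebook/opt-125m", seed=0)
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    SamplingParams(temperature=0.8),
)
for output in outputs:
    print(output)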

View File

@ -10,29 +10,20 @@ import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn
from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.async_llm_engine import AsyncLLMEngine
from cacheflow.engine.tokenizer_utils import get_tokenizer
from cacheflow.entrypoints.openai.protocol import (
CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
LogProbs,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)
TIMEOUT_KEEP_ALIVE = 5 # seconds
@ -102,11 +93,11 @@ async def create_completion(raw_request: Request):
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- echo (since the cacheflow server does not currently support
- echo (since the cacheflow engine does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported in cacheflow server)
- logit_bias (to be supported in cacheflow engine)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")
@ -116,7 +107,7 @@ async def create_completion(raw_request: Request):
return error_check_ret
if request.echo:
# We do not support echo since the cacheflow server does not
# We do not support echo since the cacheflow engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
@ -127,7 +118,7 @@ async def create_completion(raw_request: Request):
"suffix is not currently supported")
if request.logit_bias is not None:
# TODO: support logit_bias in cacheflow server.
# TODO: support logit_bias in cacheflow engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
@ -153,7 +144,7 @@ async def create_completion(raw_request: Request):
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = server.generate(prompt, sampling_params,
result_generator = engine.generate(prompt, sampling_params,
request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the
@ -163,7 +154,7 @@ async def create_completion(raw_request: Request):
not request.use_beam_search)
async def abort_request() -> None:
await server.abort(request_id)
await engine.abort(request_id)
def create_stream_response_json(index: int,
text: str,
@ -303,7 +294,7 @@ if __name__ == "__main__":
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = AsyncServerArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
app.add_middleware(
@ -318,8 +309,8 @@ if __name__ == "__main__":
served_model = args.served_model_name or args.model
server_args = AsyncServerArgs.from_cli_args(args)
server = AsyncLLMEngine.from_server_args(server_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model)
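
Because create_completion mimics the OpenAI Completion API, an OpenAI-style client can talk to this entrypoint. A hedged sketch, assuming the completion route is mounted at /v1/completions and the server runs on its default host/port; the request fields follow the OpenAI schema, and "model" should match --served-model-name (or the HuggingFace model name):

import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "San Francisco is a",
        "max_tokens": 32,
        "temperature": 0.0,
    },
)
# Unsupported fields (echo, suffix, logit_bias) are rejected by the handler
# above with HTTP 400.
print(response.json())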

View File

@ -1,12 +1,12 @@
import argparse
from cacheflow import ServerArgs, LLMEngine, SamplingParams
from cacheflow import EngineArgs, LLMEngine, SamplingParams
def main(args: argparse.Namespace):
# Parse the CLI argument and initialize the server.
server_args = ServerArgs.from_cli_args(args)
server = LLMEngine.from_server_args(server_args)
# Parse the CLI argument and initialize the engine.
engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args)
# Test the following prompts.
test_prompts = [
@ -19,27 +19,27 @@ def main(args: argparse.Namespace):
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
]
# Run the server by calling `server.step()` manually.
# Run the engine by calling `engine.step()` manually.
request_id = 0
while True:
# To test iteration-level scheduling, we add one request at each step.
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
server.add_request(str(request_id), prompt, sampling_params)
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1
request_outputs = server.step()
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished():
print(request_output)
if not (server.has_unfinished_requests() or test_prompts):
if not (engine.has_unfinished_requests() or test_prompts):
break
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Demo on using the LLMEngine class synchronously')
parser = ServerArgs.add_cli_args(parser)
description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)