mirror of https://github.com/vllm-project/vllm
Rename servers to engines (#152)
This commit is contained in:
parent bab8f3dd0d
commit e5464ee484
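For orientation, the renamed offline API after this commit looks roughly as follows. This is an illustrative sketch, not part of the commit; it mirrors the example script updated at the bottom of the diff, and the model name and prompt are placeholders.

import argparse

from cacheflow import EngineArgs, LLMEngine, SamplingParams

# EngineArgs replaces ServerArgs; LLMEngine.from_engine_args replaces
# LLMEngine.from_server_args.
parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-125m"])
engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args)

# Drive the engine manually: add a request, then call step() until finished.
engine.add_request("0", "Hello, my name is", SamplingParams(temperature=0.0))
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished():
            print(request_output)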
@@ -14,7 +14,7 @@ def main(args: argparse.Namespace):
     # Process all the requests in a single batch if possible.
     # NOTE(woosuk): If the request cannot be processed in a single batch,
-    # the server will automatically process the request in multiple batches.
+    # the engine will automatically process the request in multiple batches.
     llm = LLM(
         model=args.model,
         tensor_parallel_size=args.tensor_parallel_size,
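The hunk above touches the offline batching example built on the LLM wrapper. As a rough usage sketch (not part of the diff; it assumes LLM is importable from the package top level, as the cacheflow/__init__.py hunk further down suggests, and the prompts are placeholders):

from cacheflow import LLM, SamplingParams

# The LLM class batches all prompts at once when possible; otherwise the
# engine (renamed from "server" in this commit) splits them into multiple
# batches automatically.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    SamplingParams(temperature=0.0),
)
for output in outputs:
    print(output)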
@@ -2,7 +2,7 @@
 On the server side, run one of the following commands:
     (CacheFlow backend)
-    python -m cacheflow.entrypoints.simple_fastapi_frontend \
+    python -m cacheflow.entrypoints.api_server \
         --disable-log-requests --model <your_model>

     (TGI backend)
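For the CacheFlow backend above, a request can then be sent from the client side. The following sketch is hypothetical: only the handler body appears later in this diff ("stream" is popped from the JSON body and the remaining fields feed SamplingParams), so the "/generate" route name and the exact payload keys are assumptions, and requests is a third-party HTTP client, not part of this repo.

import json

import requests

response = requests.post(
    "http://localhost:8000/generate",  # assumed route; port 8000 is the api_server default
    json={"prompt": "Hello, my name is", "n": 1, "temperature": 0.0},
)
print(json.loads(response.text))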
@@ -84,7 +84,7 @@ def run_cacheflow(
         seed=seed,
     )

-    # Add the requests to the server.
+    # Add the requests to the engine.
     for prompt, _, output_len in requests:
         sampling_params = SamplingParams(
             n=n,
@@ -103,7 +103,7 @@ def run_cacheflow(
     start = time.time()
     # FIXME(woosuk): Do use internal method.
-    llm._run_server(use_tqdm=True)
+    llm._run_engine(use_tqdm=True)
     end = time.time()
     return end - start

@@ -1,9 +1,9 @@
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster
 from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput, CompletionOutput
+from cacheflow.outputs import CompletionOutput, RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import initialize_cluster

 __version__ = "0.1.0"

@@ -13,6 +13,6 @@ __all__ = [
     "RequestOutput",
     "CompletionOutput",
     "LLMEngine",
-    "ServerArgs",
+    "EngineArgs",
     "initialize_cluster",
 ]
@@ -216,7 +216,7 @@ class Scheduler:
         if not self.log_stats:
             return scheduler_outputs, prompt_group_ids

-        # TODO(woosuk): Move the below code to server.
+        # TODO(woosuk): Move the below code to the engine.
         now = time.time()
         if num_batched_tokens > 0:
             self.num_input_tokens.append((now, num_batched_tokens))
@@ -8,8 +8,8 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,


 @dataclass
-class ServerArgs:
-    """Arguments for CacheFlow servers."""
+class EngineArgs:
+    """Arguments for CacheFlow engine."""
     model: str
     download_dir: Optional[str] = None
     use_np_weights: bool = False
@@ -33,12 +33,12 @@ class ServerArgs:
     def add_cli_args(
         parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        """Shared CLI arguments for CacheFlow servers."""
+        """Shared CLI arguments for CacheFlow engine."""
         # Model arguments
         parser.add_argument('--model', type=str, default='facebook/opt-125m',
                             help='name or path of the huggingface model to use')
         parser.add_argument('--download-dir', type=str,
-                            default=ServerArgs.download_dir,
+                            default=EngineArgs.download_dir,
                             help='directory to download and load the weights, '
                                  'default to the default cache dir of '
                                  'huggingface')
@@ -49,7 +49,7 @@ class ServerArgs:
         parser.add_argument('--use-dummy-weights', action='store_true',
                             help='use dummy values for model weights')
         # TODO(woosuk): Support FP32.
-        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+        parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
                             choices=['auto', 'half', 'bfloat16', 'float'],
                             help='data type for model weights and activations. '
                                  'The "auto" option will use FP16 precision '
@@ -60,46 +60,46 @@ class ServerArgs:
                             help='use Ray for distributed serving, will be '
                                  'automatically set when using more than 1 GPU')
         parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-                            default=ServerArgs.pipeline_parallel_size,
+                            default=EngineArgs.pipeline_parallel_size,
                             help='number of pipeline stages')
         parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-                            default=ServerArgs.tensor_parallel_size,
+                            default=EngineArgs.tensor_parallel_size,
                             help='number of tensor parallel replicas')
         # KV cache arguments
         parser.add_argument('--block-size', type=int,
-                            default=ServerArgs.block_size,
+                            default=EngineArgs.block_size,
                             choices=[8, 16, 32],
                             help='token block size')
         # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+        parser.add_argument('--seed', type=int, default=EngineArgs.seed,
                             help='random seed')
         parser.add_argument('--swap-space', type=int,
-                            default=ServerArgs.swap_space,
+                            default=EngineArgs.swap_space,
                             help='CPU swap space size (GiB) per GPU')
         parser.add_argument('--gpu-memory-utilization', type=float,
-                            default=ServerArgs.gpu_memory_utilization,
+                            default=EngineArgs.gpu_memory_utilization,
                             help='the percentage of GPU memory to be used for'
                                  'the model executor')
         parser.add_argument('--max-num-batched-tokens', type=int,
-                            default=ServerArgs.max_num_batched_tokens,
+                            default=EngineArgs.max_num_batched_tokens,
                             help='maximum number of batched tokens per '
                                  'iteration')
         parser.add_argument('--max-num-seqs', type=int,
-                            default=ServerArgs.max_num_seqs,
+                            default=EngineArgs.max_num_seqs,
                             help='maximum number of sequences per iteration')
         parser.add_argument('--disable-log-stats', action='store_true',
                             help='disable logging statistics')
         return parser

     @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
-        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
-        return server_args
+        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return engine_args

-    def create_server_configs(
+    def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
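The renamed dataclass can also be constructed directly, without argparse, which is how the LLM entrypoint further down in this diff builds it. A small sketch (not part of the diff; the field values are placeholders):

from cacheflow.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="facebook/opt-125m", seed=0)
# create_engine_configs (renamed from create_server_configs) unpacks the
# dataclass into the four config objects consumed by LLMEngine.
model_config, cache_config, parallel_config, scheduler_config = (
    engine_args.create_engine_configs())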
@@ -117,19 +117,19 @@ class ServerArgs:


 @dataclass
-class AsyncServerArgs(ServerArgs):
-    """Arguments for asynchronous CacheFlow servers."""
-    server_use_ray: bool = False
+class AsyncEngineArgs(EngineArgs):
+    """Arguments for asynchronous CacheFlow engine."""
+    engine_use_ray: bool = False
     disable_log_requests: bool = False

     @staticmethod
     def add_cli_args(
         parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        parser = ServerArgs.add_cli_args(parser)
-        parser.add_argument('--server-use-ray', action='store_true',
-                            help='use Ray to start the LLM server in a '
-                                 'separate process as the web server process.')
+        parser = EngineArgs.add_cli_args(parser)
+        parser.add_argument('--engine-use-ray', action='store_true',
+                            help='use Ray to start the LLM engine in a '
+                                 'separate process as the server process.')
         parser.add_argument('--disable-log-requests', action='store_true',
                             help='disable logging requests')
         return parser
@@ -2,12 +2,12 @@ import asyncio
 import time
 from typing import Dict, List, Optional

+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster, ray
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import ray, initialize_cluster

 logger = init_logger(__name__)

@@ -29,44 +29,44 @@ class AsyncLLMEngine:
         worker_use_ray: Whether to use Ray for model workers. Required for
             distributed execution. Should be the same as
             `parallel_config.worker_use_ray`.
-        server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
+        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
         *args, *kwargs: Arguments for LLMEngine.
     """
-    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+    def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
                  log_requests: bool = True, *args, **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
-        self.server_use_ray = server_use_ray
+        self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        if not self.server_use_ray:
-            server_class = LLMEngine
+        if not self.engine_use_ray:
+            engine_class = LLMEngine
         elif self.worker_use_ray:
-            server_class = ray.remote(num_cpus=0)(LLMEngine).remote
+            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
         else:
-            server_class = ray.remote(num_gpus=1)(LLMEngine).remote
-        self.server = server_class(*args, **kwargs)
+            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
+        self.engine = engine_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
         # Request id -> event to notify that there is new output.
         self.request_events: Dict[str, asyncio.Event] = {}
-        self.is_server_running = False
+        self.is_engine_running = False
         self.kicking_request_id: Optional[str] = None

-    async def server_step(self, kicking_request_id: Optional[str] = None):
-        """Kick the server to process the waiting requests."""
-        self.is_server_running = True
+    async def engine_step(self, kicking_request_id: Optional[str] = None):
+        """Kick the engine to process the waiting requests."""
+        self.is_engine_running = True
         self.kicking_request_id = kicking_request_id
-        if self.server_use_ray:
-            request_outputs = await self.server.step.remote()
+        if self.engine_use_ray:
+            request_outputs = await self.engine.step.remote()
         else:
             # Yield to the event loop to allow other coroutines to run
-            # while is_server_running is True. This let the server to add new
+            # while is_engine_running is True. This let the engine to add new
             # requests into the queue.
             await asyncio.sleep(0)
-            request_outputs = self.server.step()
-        self.is_server_running = False
+            request_outputs = self.engine.step()
+        self.is_engine_running = False
         self.kicking_request_id = None

         # Notify the waiting coroutines that there are new outputs ready.
@@ -104,7 +104,7 @@ class AsyncLLMEngine:
         arrival_time = time.time()

         # Create an event to notify us that there is new output from the
-        # cacheflow server.
+        # cacheflow engine.
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

@@ -114,31 +114,31 @@ class AsyncLLMEngine:
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")

-        # Add the request into the cacheflow server's waiting queue.
-        if self.server_use_ray:
-            await self.server.add_request.remote(
+        # Add the request into the cacheflow engine's waiting queue.
+        if self.engine_use_ray:
+            await self.engine.add_request.remote(
                 request_id, prompt, sampling_params,
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time)
         else:
-            self.server.add_request(
+            self.engine.add_request(
                 request_id, prompt, sampling_params,
                 prompt_token_ids=prompt_token_ids,
                 arrival_time=arrival_time)

-        # The cacheflow server does not have a background loop that keeps
+        # The cacheflow engine does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
-        # the server to process the requests.
+        # the engine to process the requests.
         while True:
             if request_id not in self.request_events:
                 # The request has been aborted.
                 return

-            # Kick the server if the server is not running.
-            if not self.is_server_running:
-                await self.server_step(request_id)
+            # Kick the engine if the engine is not running.
+            if not self.is_engine_running:
+                await self.engine_step(request_id)

-            # Wait for new output. The group_event will be set in server_step
+            # Wait for new output. The group_event will be set in engine_step
             # when there is new output available for the sequence group.
             # Added a timeout to prevent deadlock.
             try:
@@ -160,11 +160,11 @@ class AsyncLLMEngine:
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
-                # Kick the server if the server is not running. This is to
-                # prevent that there are still requests in server's waiting
+                # Kick the engine if the engine is not running. This is to
+                # prevent that there are still requests in engine's waiting
                 # queue to be executed.
-                if not self.is_server_running:
-                    await self.server_step()
+                if not self.is_engine_running:
+                    await self.engine_step()
                 break

     async def abort(self, request_id: str) -> None:
@@ -183,36 +183,36 @@ class AsyncLLMEngine:
         if self.log_requests:
             logger.info(f"Aborted request {request_id}.")

-        if self.server_use_ray:
-            await self.server.abort_request.remote(request_id)
+        if self.engine_use_ray:
+            await self.engine.abort_request.remote(request_id)
         else:
-            self.server.abort_request(request_id)
+            self.engine.abort_request(request_id)

         if request_id in self.request_events:
             del self.request_events[request_id]
         if request_id in self.request_outputs:
             del self.request_outputs[request_id]

-        # To prevent deadlock when a request is aborted while the server is
+        # To prevent deadlock when a request is aborted while the engine is
         # running.
         if self.kicking_request_id == request_id:
-            self.is_server_running = False
+            self.is_engine_running = False
             self.kicking_request_id = None

     @classmethod
-    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
-        """Creates an async LLM server from the server arguments."""
-        # Create the server configs.
-        server_configs = server_args.create_server_configs()
-        parallel_config = server_configs[2]
+    def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(
-            parallel_config, server_args.server_use_ray)
-        # Create the LLM server.
-        server = cls(server_args.worker_use_ray,
-                     server_args.server_use_ray,
-                     not server_args.disable_log_requests,
-                     *server_configs,
+            parallel_config, engine_args.engine_use_ray)
+        # Create the async LLM engine.
+        engine = cls(engine_args.worker_use_ray,
+                     engine_args.engine_use_ray,
+                     not engine_args.disable_log_requests,
+                     *engine_configs,
                      distributed_init_method, devices,
-                     log_stats=not server_args.disable_log_stats)
-        return server
+                     log_stats=not engine_args.disable_log_stats)
+        return engine
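Taken together, the renamed async path is used roughly as below. This is an illustrative sketch, not part of the diff; it mirrors what the FastAPI entrypoints further down do: build an AsyncLLMEngine from AsyncEngineArgs and consume engine.generate(...) as an async iterator of RequestOutput objects. The prompt is a placeholder.

import argparse
import asyncio

from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.async_llm_engine import AsyncLLMEngine
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid


async def main() -> None:
    parser = AsyncEngineArgs.add_cli_args(argparse.ArgumentParser())
    engine_args = AsyncEngineArgs.from_cli_args(parser.parse_args())
    # from_engine_args replaces from_server_args.
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    request_id = random_uuid()
    async for request_output in engine.generate(
            "Hello, my name is", SamplingParams(temperature=0.0), request_id):
        if request_output.finished():
            print(request_output)


asyncio.run(main())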
@@ -4,13 +4,13 @@ from typing import Any, List, Optional
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray
+from cacheflow.engine.tokenizer_utils import (detokenize_incrementally,
+                                              get_tokenizer)
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
-from cacheflow.server.tokenizer_utils import (get_tokenizer,
-                                              detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker
@@ -19,9 +19,9 @@ logger = init_logger(__name__)


 class LLMEngine:
-    """An LLM server that receives requests and generates texts.
+    """An LLM engine that receives requests and generates texts.

-    This is the main class for the CacheFlow LLM server. It receives requests
+    This is the main class for the CacheFlow LLM engine. It receives requests
     from clients and generates texts from the LLM. It includes a tokenizer, a
     language model (possibly distributed across multiple GPUs), and GPU memory
     space allocated for intermediate states (aka KV cache). This class utilizes
@@ -31,8 +31,8 @@ class LLMEngine:
     The `LLM` class wraps this class for offline batched inference and the
     `AsyncLLMEngine` class wraps this class for online serving.

-    NOTE: The config arguments are derived from the `ServerArgs` class. For the
-    comprehensive list of arguments, see `ServerArgs`.
+    NOTE: The config arguments are derived from the `EngineArgs` class. For the
+    comprehensive list of arguments, see `EngineArgs`.

     Args:
         model_config: The configuration related to the LLM model.
@@ -58,7 +58,7 @@ class LLMEngine:
         log_stats: bool,
     ) -> None:
         logger.info(
-            "Initializing an LLM server with config: "
+            "Initializing an LLM engine with config: "
             f"model={model_config.model!r}, "
             f"dtype={model_config.dtype}, "
             f"use_dummy_weights={model_config.use_dummy_weights}, "
@@ -135,17 +135,17 @@ class LLMEngine:
         self._run_workers("init_cache_engine", cache_config=self.cache_config)

     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine":
-        """Creates an LLM server from the server arguments."""
-        # Create the server configs.
-        server_configs = server_args.create_server_configs()
-        parallel_config = server_configs[2]
+    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(parallel_config)
-        # Create the LLM server.
-        server = cls(*server_configs, distributed_init_method, devices,
-                     log_stats=not server_args.disable_log_stats)
-        return server
+        # Create the LLM engine.
+        engine = cls(*engine_configs, distributed_init_method, devices,
+                     log_stats=not engine_args.disable_log_stats)
+        return engine

     def add_request(
         self,
@@ -155,10 +155,10 @@ class LLMEngine:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
     ) -> None:
-        """Add a request to the server's request pool.
+        """Add a request to the engine's request pool.

         The request is added to the request pool and will be processed by the
-        scheduler as `server.step()` is called. The exact scheduling policy is
+        scheduler as `engine.step()` is called. The exact scheduling policy is
         determined by the scheduler.

         Args:
@@ -211,7 +211,7 @@ class LLMEngine:
     def step(self) -> List[RequestOutput]:
         """Performs one decoding iteration and returns newly generated results.

-        This function performs one decoding iteration for the server. It first
+        This function performs one decoding iteration of the engine. It first
         schedules the sequences to be executed in the next iteration and the
         token blocks to be swapped in/out/copy. Then, it executes the model
         and updates the scheduler with the model outputs. Finally, it decodes
@@ -13,15 +13,15 @@ DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
 def initialize_cluster(
     parallel_config: ParallelConfig,
-    server_use_ray: bool = False,
-    ray_server_address: Optional[str] = None,
+    engine_use_ray: bool = False,
+    ray_address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
     """Initialize the distributed cluster probably with Ray.

     Args:
         parallel_config: The configurations for parallel execution.
-        server_use_ray: Whether to use Ray for async server.
-        ray_server_address: The address of the Ray cluster. If None, uses
+        engine_use_ray: Whether to use Ray for async engine.
+        ray_address: The address of the Ray cluster. If None, uses
             the default Ray cluster address.

     Returns:
@@ -31,13 +31,13 @@ def initialize_cluster(
         each worker in each pipeline stage. Each device ID is a tuple of
         (rank, node resource, device id).
     """
-    if parallel_config.worker_use_ray or server_use_ray:
+    if parallel_config.worker_use_ray or engine_use_ray:
         if ray is None:
             raise ImportError(
                 "Ray is not installed. Please install Ray to use distributed "
                 "serving.")
         # Connect to a ray cluster.
-        ray.init(address=ray_server_address)
+        ray.init(address=ray_address)

     if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
@@ -6,9 +6,9 @@ from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import Response, StreamingResponse
 import uvicorn

+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5 # seconds.
@@ -30,7 +30,7 @@ async def generate(request: Request) -> Response:
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
-    results_generator = server.generate(prompt, sampling_params, request_id)
+    results_generator = engine.generate(prompt, sampling_params, request_id)

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -44,7 +44,7 @@ async def generate(request: Request) -> Response:
             yield (json.dumps(ret) + "\0").encode("utf-8")

     async def abort_request() -> None:
-        await server.abort(request_id)
+        await engine.abort(request_id)

     if stream:
         background_tasks = BackgroundTasks()
@@ -57,7 +57,7 @@ async def generate(request: Request) -> Response:
     async for request_output in results_generator:
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
-            await server.abort(request_id)
+            await engine.abort(request_id)
             return Response(status_code=499)
         final_output = request_output

@@ -75,11 +75,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8000)
-    parser = AsyncServerArgs.add_cli_args(parser)
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMEngine.from_server_args(server_args)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)

     uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
@@ -1,12 +1,12 @@
 from typing import List, Optional, Union

-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from tqdm import tqdm
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter


@@ -21,7 +21,7 @@ class LLM:
     NOTE: This class is intended to be used for offline inference. For online
     serving, use the `AsyncLLMEngine` class instead.
-    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

     Args:
         model: The name or path of a HuggingFace Transformers model.
@@ -45,20 +45,20 @@ class LLM:
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        server_args = ServerArgs(
+        engine_args = EngineArgs(
             model=model,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
             **kwargs,
         )
-        self.llm_server = LLMEngine.from_server_args(server_args)
+        self.llm_engine = LLMEngine.from_engine_args(engine_args)
         self.request_counter = Counter()

     def get_tokenizer(
         self,
     ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_server.tokenizer
+        return self.llm_engine.tokenizer

     def generate(
         self,
@@ -99,7 +99,7 @@ class LLM:
             # Use default sampling params.
             sampling_params = SamplingParams()

-        # Add requests to the server.
+        # Add requests to the engine.
         if prompts is not None:
             num_requests = len(prompts)
         else:
@@ -111,7 +111,7 @@ class LLM:
             else:
                 token_ids = prompt_token_ids[i]
             self._add_request(prompt, sampling_params, token_ids)
-        return self._run_server(use_tqdm)
+        return self._run_engine(use_tqdm)

     def _add_request(
         self,
@@ -120,18 +120,18 @@ class LLM:
         prompt_token_ids: Optional[List[int]],
     ) -> None:
         request_id = str(next(self.request_counter))
-        self.llm_server.add_request(request_id, prompt, sampling_params,
+        self.llm_engine.add_request(request_id, prompt, sampling_params,
                                     prompt_token_ids)

-    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.
         if use_tqdm:
-            num_requests = self.llm_server.get_num_unfinished_requests()
+            num_requests = self.llm_engine.get_num_unfinished_requests()
             pbar = tqdm(total=num_requests, desc="Processed prompts")
-        # Run the server.
+        # Run the engine.
         outputs: List[RequestOutput] = []
-        while self.llm_server.has_unfinished_requests():
-            step_outputs = self.llm_server.step()
+        while self.llm_engine.has_unfinished_requests():
+            step_outputs = self.llm_engine.step()
             for output in step_outputs:
                 if output.finished():
                     outputs.append(output)
@@ -10,29 +10,20 @@ import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import uvicorn

-from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
+from cacheflow.engine.tokenizer_utils import get_tokenizer
+from cacheflow.entrypoints.openai.protocol import (
+    CompletionRequest, CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from cacheflow.logger import init_logger
+from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import random_uuid
-from cacheflow.entrypoints.openai.protocol import (
-    CompletionRequest,
-    CompletionResponse,
-    CompletionResponseChoice,
-    CompletionResponseStreamChoice,
-    CompletionStreamResponse,
-    ErrorResponse,
-    LogProbs,
-    ModelCard,
-    ModelList,
-    ModelPermission,
-    UsageInfo,
-)

 TIMEOUT_KEEP_ALIVE = 5 # seconds

@@ -102,11 +93,11 @@ async def create_completion(raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.

     NOTE: Currently we do not support the following features:
-        - echo (since the cacheflow server does not currently support
+        - echo (since the cacheflow engine does not currently support
           getting the logprobs of prompt tokens)
        - suffix (the language models we currently support do not support
          suffix)
-        - logit_bias (to be supported in cacheflow server)
+        - logit_bias (to be supported in cacheflow engine)
     """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
@@ -116,7 +107,7 @@ async def create_completion(raw_request: Request):
         return error_check_ret

     if request.echo:
-        # We do not support echo since the cacheflow server does not
+        # We do not support echo since the cacheflow engine does not
         # currently support getting the logprobs of prompt tokens.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "echo is not currently supported")
@@ -127,7 +118,7 @@ async def create_completion(raw_request: Request):
                                      "suffix is not currently supported")

     if request.logit_bias is not None:
-        # TODO: support logit_bias in cacheflow server.
+        # TODO: support logit_bias in cacheflow engine.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "logit_bias is not currently supported")

@@ -153,7 +144,7 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = server.generate(prompt, sampling_params,
+    result_generator = engine.generate(prompt, sampling_params,
                                        request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -163,7 +154,7 @@ async def create_completion(raw_request: Request):
               not request.use_beam_search)

     async def abort_request() -> None:
-        await server.abort(request_id)
+        await engine.abort(request_id)

     def create_stream_response_json(index: int,
                                     text: str,
@@ -303,7 +294,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = AsyncServerArgs.add_cli_args(parser)
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(
@@ -318,8 +309,8 @@ if __name__ == "__main__":
     served_model = args.served_model_name or args.model

-    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMEngine.from_server_args(server_args)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)
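The OpenAI-compatible entrypoint touched above can then be queried like the OpenAI Completion API it mimics. The sketch below is hypothetical: the "/v1/completions" route and the port are assumptions (only the handler and CLI wiring appear in these hunks), and requests is a third-party HTTP client, not part of this repo.

import requests

response = requests.post(
    "http://localhost:8000/v1/completions",  # assumed route and port
    json={
        "model": "facebook/opt-125m",   # the served model name
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "temperature": 0.0,
    },
)
print(response.json())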
@@ -1,12 +1,12 @@
 import argparse

-from cacheflow import ServerArgs, LLMEngine, SamplingParams
+from cacheflow import EngineArgs, LLMEngine, SamplingParams


 def main(args: argparse.Namespace):
-    # Parse the CLI argument and initialize the server.
-    server_args = ServerArgs.from_cli_args(args)
-    server = LLMEngine.from_server_args(server_args)
+    # Parse the CLI argument and initialize the engine.
+    engine_args = EngineArgs.from_cli_args(args)
+    engine = LLMEngine.from_engine_args(engine_args)

     # Test the following prompts.
     test_prompts = [
@@ -19,27 +19,27 @@ def main(args: argparse.Namespace):
         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]

-    # Run the server by calling `server.step()` manually.
+    # Run the engine by calling `engine.step()` manually.
     request_id = 0
     while True:
         # To test iteration-level scheduling, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
-            server.add_request(str(request_id), prompt, sampling_params)
+            engine.add_request(str(request_id), prompt, sampling_params)
             request_id += 1

-        request_outputs = server.step()
+        request_outputs = engine.step()
         for request_output in request_outputs:
             if request_output.finished():
                 print(request_output)

-        if not (server.has_unfinished_requests() or test_prompts):
+        if not (engine.has_unfinished_requests() or test_prompts):
             break


 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Demo on using the LLMEngine class synchronously')
+        description='Demo on using the LLMEngine class directly')
-    parser = ServerArgs.add_cli_args(parser)
+    parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)