mirror of https://github.com/vllm-project/vllm
Fix various issues of async servers (#135)
This commit is contained in:
parent 8274ca23ac
commit 1a956e136b
@@ -0,0 +1,58 @@
+import argparse
+import json
+import threading
+import time
+
+import requests
+
+
+def main(args: argparse.Namespace):
+    prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
+               for i in range(args.n_threads)]
+
+    headers = {"User-Agent": "CacheFlow Benchmark Client"}
+    ploads = [{
+        "prompt": p,
+        "max_tokens": args.max_tokens,
+        "temperature": 0.0,
+        "ignore_eos": True,
+    } for p in prompts]
+
+    def send_request(results, i):
+        response = requests.post(args.api_url, headers=headers,
+                                 json=ploads[i], stream=True)
+        results[i] = response
+
+    # use args.n_threads to prompt the backend
+    tik = time.time()
+    threads = []
+    results = [None] * args.n_threads
+    for i in range(args.n_threads):
+        t = threading.Thread(target=send_request, args=(results, i))
+        t.start()
+        threads.append(t)
+
+    for t in threads:
+        t.join()
+
+    print(f"Time (POST): {time.time() - tik} s")
+    n_words = 0
+
+    for i, response in enumerate(results):
+        k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"))
+        response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0]
+        n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))
+
+    time_seconds = time.time() - tik
+    print(f"Time (total): {time_seconds:.3f}s to finish, n_threads: {args.n_threads}, "
+          f"throughput: {n_words / time_seconds} words/s.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
+    parser.add_argument("--max-tokens", type=int, default=128)
+    parser.add_argument("--n-threads", type=int, default=128)
+    args = parser.parse_args()
+
+    main(args)
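For illustration, not part of the diff: the new benchmark client above exercises the plain /generate endpoint touched later in this commit. A minimal single-request sketch of the same protocol, assuming a server running with the defaults used in this diff (localhost:8001, null-delimited JSON chunks carrying a "text" list), would be:

import json
import requests

payload = {"prompt": "Tell me a story", "max_tokens": 32,
           "temperature": 0.0, "ignore_eos": True}
response = requests.post("http://localhost:8001/generate",
                         headers={"User-Agent": "CacheFlow Benchmark Client"},
                         json=payload, stream=True)
# The server streams cumulative results as "\0"-delimited JSON objects;
# the last complete chunk holds the full generated text.
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk.decode("utf-8"))["text"][0])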
@@ -116,15 +116,15 @@ class ParallelConfig:
         self,
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
-        use_ray: bool,
+        worker_use_ray: bool,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
-        self.use_ray = use_ray
+        self.worker_use_ray = worker_use_ray
 
         self.world_size = pipeline_parallel_size * tensor_parallel_size
         if self.world_size > 1:
-            self.use_ray = True
+            self.worker_use_ray = True
         self._verify_args()
 
     def _verify_args(self) -> None:
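For illustration, not part of the diff (and assuming the constructor shown in this hunk is the full signature): ParallelConfig now tracks Ray usage per worker and forces it on whenever more than one GPU is involved.

# worker_use_ray is promoted to True as soon as world_size > 1.
config = ParallelConfig(pipeline_parallel_size=1,
                        tensor_parallel_size=2,
                        worker_use_ray=False)
assert config.world_size == 2
assert config.worker_use_ray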
@@ -148,7 +148,7 @@ class BlockSpaceManager:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ class BlockSpaceManager:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ class BlockSpaceManager:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if SequenceStatus.is_finished(seq.status):
+            if seq.is_finished():
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -231,6 +231,9 @@ class BlockSpaceManager:
                 self.cpu_allocator.free(block)
 
     def free(self, seq: Sequence) -> None:
+        if seq.seq_id not in self.block_tables:
+            # Already freed or haven't been scheduled yet.
+            return
         block_table = self.block_tables[seq.seq_id]
         self._free_block_table(block_table)
         del self.block_tables[seq.seq_id]
@@ -12,7 +12,7 @@ from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
 
 logger = init_logger(__name__)
 
-_LOGGING_INTERVAL_SEC = 10
+_LOGGING_INTERVAL_SEC = 5
 
 
 class PreemptionMode(enum.Enum):
@@ -84,6 +84,18 @@ class Scheduler:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
 
+    def abort_seq_group(self, request_id: str) -> None:
+        for state_queue in [self.waiting, self.running, self.swapped]:
+            for seq_group in state_queue:
+                if seq_group.request_id == request_id:
+                    # Remove the sequence group from the state queue.
+                    state_queue.remove(seq_group)
+                    for seq in seq_group.seqs:
+                        if seq.is_finished():
+                            continue
+                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    return
+
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
 
@@ -7,13 +7,14 @@ import time
 from typing import AsyncGenerator, Dict, List, Optional
 
 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 
 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
@@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import (
     UsageInfo,
 )
 
+TIMEOUT_KEEP_ALIVE = 5  # seconds
 
 logger = init_logger(__name__)
 served_model = None
@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],
 
 
 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
 
     error_check_ret = await check_model(request)
@@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when use beam search.
@@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest):
               (request.best_of is None or request.n == request.best_of) and
               not request.use_beam_search)
 
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,
@@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest):
 
     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)
 
     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []
@ -276,7 +291,7 @@ if __name__ == "__main__":
|
|||
help="The model name used in the API. If not specified, "
|
||||
"the model name will be the same as the "
|
||||
"huggingface name.")
|
||||
parser = ServerArgs.add_cli_args(parser)
|
||||
parser = AsyncServerArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
app.add_middleware(
|
||||
|
@ -291,10 +306,11 @@ if __name__ == "__main__":
|
|||
|
||||
served_model = args.served_model_name or args.model
|
||||
|
||||
server_args = ServerArgs.from_cli_args(args)
|
||||
server_args = AsyncServerArgs.from_cli_args(args)
|
||||
server = AsyncLLMServer.from_server_args(server_args)
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
tokenizer = get_tokenizer(args.model)
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
||||
uvicorn.run(app, host=args.host, port=args.port, log_level="info",
|
||||
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
|
||||
|
|
|
@@ -2,15 +2,16 @@ import argparse
 import json
 from typing import AsyncGenerator
 
-from fastapi import FastAPI, Request
+from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
 import uvicorn
 
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.ray_utils import initialize_cluster
 from cacheflow.utils import random_uuid
+
 TIMEOUT_KEEP_ALIVE = 5  # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
 app = FastAPI()
-
@@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse:
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
-    results_generator = server.generate(prompt, sampling_params)
+    request_id = random_uuid()
+    results_generator = server.generate(prompt, sampling_params, request_id)
 
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
@@ -35,17 +37,24 @@ async def generate_stream(request: Request) -> StreamingResponse:
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
-    return StreamingResponse(stream_results())
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
+    background_tasks = BackgroundTasks()
+    # Abort the request if the client disconnects.
+    background_tasks.add_task(abort_request)
+    return StreamingResponse(stream_results(), background=background_tasks)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8001)
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()
 
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)
 
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
@@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
     SWAPPED = enum.auto()
     FINISHED_STOPPED = enum.auto()
     FINISHED_LENGTH_CAPPED = enum.auto()
+    FINISHED_ABORTED = enum.auto()
 
     @staticmethod
     def is_finished(status: "SequenceStatus") -> bool:
         return status in [
             SequenceStatus.FINISHED_STOPPED,
             SequenceStatus.FINISHED_LENGTH_CAPPED,
+            SequenceStatus.FINISHED_ABORTED,
         ]
 
     @staticmethod
@@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum):
             finish_reason = "stop"
         elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
             finish_reason = "length"
+        elif status == SequenceStatus.FINISHED_ABORTED:
+            finish_reason = "abort"
         else:
             finish_reason = None
         return finish_reason
 
 
 class SequenceData:
+
     def __init__(
@@ -137,6 +142,9 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob
 
+    def is_finished(self) -> bool:
+        return SequenceStatus.is_finished(self.status)
+
     def fork(self, child_seq: 'Sequence') -> None:
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
@@ -182,7 +190,7 @@ class SequenceGroup:
         raise ValueError(f'Sequence {seq_id} not found.')
 
     def is_finished(self) -> bool:
-        return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
+        return all(seq.is_finished() for seq in self.seqs)
 
     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
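For illustration, not part of the diff: the new FINISHED_ABORTED state and the Sequence.is_finished() helper added above compose as follows; `seq` is assumed to be an existing Sequence that the scheduler just aborted.

seq.status = SequenceStatus.FINISHED_ABORTED
assert seq.is_finished()                       # new per-sequence helper
assert SequenceStatus.is_finished(seq.status)  # the static check it wraps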
@@ -15,7 +15,7 @@ class ServerArgs:
     use_dummy_weights: bool = False
     dtype: str = "default"
     seed: int = 0
-    use_ray: bool = False
+    worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
     block_size: int = 16
@@ -32,7 +32,63 @@ class ServerArgs:
     def add_cli_args(
         parser: argparse.ArgumentParser,
     ) -> argparse.ArgumentParser:
-        return _add_server_arguments(parser)
+        """Shared CLI arguments for CacheFlow servers."""
+        # Model arguments
+        parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                            help='name or path of the huggingface model to use')
+        parser.add_argument('--download-dir', type=str,
+                            default=ServerArgs.download_dir,
+                            help='directory to download and load the weights, '
+                                 'default to the default cache dir of '
+                                 'huggingface')
+        parser.add_argument('--use-np-weights', action='store_true',
+                            help='save a numpy copy of model weights for '
+                                 'faster loading. This can increase the disk '
+                                 'usage by up to 2x.')
+        parser.add_argument('--use-dummy-weights', action='store_true',
+                            help='use dummy values for model weights')
+        # TODO(woosuk): Support FP32.
+        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                            choices=['default', 'half', 'bfloat16'],
+                            help='data type for model weights and activations. '
+                                 'The "default" option will use FP16 precision '
+                                 'for FP32 and FP16 models, and BF16 precision '
+                                 'for BF16 models.')
+        # Parallel arguments
+        parser.add_argument('--worker-use-ray', action='store_true',
+                            help='use Ray for distributed serving, will be '
+                                 'automatically set when using more than 1 GPU')
+        parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                            default=ServerArgs.pipeline_parallel_size,
+                            help='number of pipeline stages')
+        parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                            default=ServerArgs.tensor_parallel_size,
+                            help='number of tensor parallel replicas')
+        # KV cache arguments
+        parser.add_argument('--block-size', type=int,
+                            default=ServerArgs.block_size,
+                            choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                            help='token block size')
+        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
+        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                            help='random seed')
+        parser.add_argument('--swap-space', type=int,
+                            default=ServerArgs.swap_space,
+                            help='CPU swap space size (GiB) per GPU')
+        parser.add_argument('--gpu-memory-utilization', type=float,
+                            default=ServerArgs.gpu_memory_utilization,
+                            help='the percentage of GPU memory to be used for'
+                                 'the model executor')
+        parser.add_argument('--max-num-batched-tokens', type=int,
+                            default=ServerArgs.max_num_batched_tokens,
+                            help='maximum number of batched tokens per '
+                                 'iteration')
+        parser.add_argument('--max-num-seqs', type=int,
+                            default=ServerArgs.max_num_seqs,
+                            help='maximum number of sequences per iteration')
+        parser.add_argument('--disable-log-stats', action='store_true',
+                            help='disable logging statistics')
+        return parser
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
@@ -53,65 +109,22 @@ class ServerArgs:
                                    self.swap_space)
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
-                                         self.use_ray)
+                                         self.worker_use_ray)
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs)
         return model_config, cache_config, parallel_config, scheduler_config
 
 
-def _add_server_arguments(
-    parser: argparse.ArgumentParser,
-)-> argparse.ArgumentParser:
-    """Shared CLI arguments for CacheFlow servers."""
-    # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m',
-                        help='name or path of the huggingface model to use')
-    parser.add_argument('--download-dir', type=str,
-                        default=ServerArgs.download_dir,
-                        help='directory to download and load the weights, '
-                             'default to the default cache dir of huggingface')
-    parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster '
-                             'loading. This can increase the disk usage by up '
-                             'to 2x.')
-    parser.add_argument('--use-dummy-weights', action='store_true',
-                        help='use dummy values for model weights')
-    # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
-                        choices=['default', 'half', 'bfloat16'],
-                        help=('data type for model weights and activations. '
-                              'The "default" option will use FP16 precision '
-                              'for FP32 and FP16 models, and BF16 precision '
-                              'for BF16 models.'))
-    # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true',
-                        help='use Ray for distributed serving, will be '
-                             'automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-                        default=ServerArgs.pipeline_parallel_size,
-                        help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-                        default=ServerArgs.tensor_parallel_size,
-                        help='number of tensor parallel replicas')
-    # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
-                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
-                        help='token block size')
-    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
-                        help='random seed')
-    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
-                        help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float,
-                        default=ServerArgs.gpu_memory_utilization,
-                        help='the percentage of GPU memory to be used for the '
-                             'model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int,
-                        default=ServerArgs.max_num_batched_tokens,
-                        help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int,
-                        default=ServerArgs.max_num_seqs,
-                        help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true',
-                        help='disable logging statistics')
-    return parser
+@dataclass
+class AsyncServerArgs(ServerArgs):
+    server_use_ray: bool = False
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        parser = ServerArgs.add_cli_args(parser)
+        parser.add_argument('--server-use-ray', action='store_true',
+                            help='use Ray to start the LLM server in a '
+                                 'separate process as the web server process.')
+        return parser
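For illustration, not part of the diff: the async entrypoints are expected to build their CLI from the new AsyncServerArgs, which layers the single --server-use-ray flag on top of the shared ServerArgs arguments. A minimal sketch, using only names introduced in this commit:

import argparse

parser = argparse.ArgumentParser()
parser = AsyncServerArgs.add_cli_args(parser)
# e.g. run the LLMServer in its own Ray process with 2-way tensor parallelism
args = parser.parse_args(["--tensor-parallel-size", "2", "--server-use-ray"])
server_args = AsyncServerArgs.from_cli_args(args)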
@@ -2,37 +2,52 @@ import asyncio
 import time
 from typing import Dict, Optional
 
-import ray
-
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-from cacheflow.utils import random_uuid
+from cacheflow.server.ray_utils import ray, initialize_cluster
 
 logger = init_logger(__name__)
 
+TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
+
+
 class AsyncLLMServer:
 
-    def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
-        if server_use_ray:
-            remote_server_class = ray.remote(num_cpus=0)(LLMServer)
+    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+                 *args, **kwargs) -> None:
+        self.worker_use_ray = worker_use_ray
+        self.server_use_ray = server_use_ray
+        if not self.server_use_ray:
+            server_class = LLMServer
+        elif self.worker_use_ray:
+            server_class = ray.remote(num_cpus=0)(LLMServer).remote
         else:
-            remote_server_class = ray.remote(num_gpus=1)(LLMServer)
-        self.server = remote_server_class.remote(*args, **kwargs)
-
+            server_class = ray.remote(num_gpus=1)(LLMServer).remote
+        self.server = server_class(*args, **kwargs)
         # Request id -> request output.
         self.request_outputs: Dict[str, RequestOutput] = {}
         # Request id -> event to notify that there is new output.
         self.request_events: Dict[str, asyncio.Event] = {}
         self.is_server_running = False
+        self.kicking_request_id: Optional[str] = None
 
-    async def server_step(self):
+    async def server_step(self, kicking_request_id: Optional[str] = None):
         self.is_server_running = True
-        request_outputs = await self.server.step.remote()
+        self.kicking_request_id = kicking_request_id
+        if self.server_use_ray:
+            request_outputs = await self.server.step.remote()
+        else:
+            # Yield to the event loop to allow other coroutines to run
+            # while is_server_running is True. This let the server to add new
+            # requests into the queue.
+            await asyncio.sleep(0)
+            request_outputs = self.server.step()
         self.is_server_running = False
+        self.kicking_request_id = None
+
         # Notify the waiting coroutines that there are new outputs ready.
         for request_output in request_outputs:
             request_id = request_output.request_id
@@ -40,20 +55,26 @@ class AsyncLLMServer:
             self.request_events[request_id].set()
 
     async def generate(self, prompt: str, sampling_params: SamplingParams,
-                       request_id: Optional[str] = None) -> RequestOutput:
+                       request_id: str) -> RequestOutput:
         # Preprocess the request.
         arrival_time = time.time()
 
         # Create an event to notify us that there is new output from the
         # cacheflow server.
-        if request_id is None:
-            request_id = random_uuid()
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event
 
+        logger.info(f"Received request {request_id}: "
+                    f"prompt: {prompt!r}, "
+                    f"sampling params: {sampling_params}.")
+
         # Add the request into the cacheflow server's waiting queue.
-        await self.server.add_request.remote(
-            request_id, prompt, sampling_params, arrival_time=arrival_time)
+        if self.server_use_ray:
+            await self.server.add_request.remote(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)
+        else:
+            self.server.add_request(
+                request_id, prompt, sampling_params, arrival_time=arrival_time)
 
         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
@@ -61,7 +82,7 @@ class AsyncLLMServer:
         while True:
             # Kick the server if the server is not running.
             if not self.is_server_running:
-                await self.server_step()
+                await self.server_step(request_id)
 
             # Wait for new output. The group_event will be set in server_step
             # when there is new output available for the sequence group.
@@ -80,6 +101,8 @@ class AsyncLLMServer:
 
             # Once finished, release the resources of the sequence group.
             if request_output.finished():
+                logger.info(f"Finished request {request_id}.")
+
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
                 # Kick the server if the server is not running. This is to
@@ -89,15 +112,41 @@ class AsyncLLMServer:
                     await self.server_step()
                 break
 
+    async def abort(self, request_id: str) -> None:
+        if request_id not in self.request_events:
+            # The request has already finished or been aborted.
+            return
+
+        logger.info(f"Aborted request {request_id}.")
+
+        if self.server_use_ray:
+            await self.server.abort_request.remote(request_id)
+        else:
+            self.server.abort_request(request_id)
+
+        if request_id in self.request_events:
+            del self.request_events[request_id]
+        if request_id in self.request_outputs:
+            del self.request_outputs[request_id]
+
+        # To prevent deadlock when a request is aborted while the server is
+        # running.
+        if self.kicking_request_id == request_id:
+            self.is_server_running = False
+            self.kicking_request_id = None
+
     @classmethod
-    def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
+    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
         # Initialize the cluster.
-        distributed_init_method, devices = initialize_cluster(parallel_config)
+        distributed_init_method, devices = initialize_cluster(
+            parallel_config, server_args.server_use_ray)
         # Create the LLM server.
-        server = cls(server_args.use_ray, *server_configs,
+        server = cls(server_args.worker_use_ray,
+                     server_args.server_use_ray,
+                     *server_configs,
                      distributed_init_method, devices,
                      log_stats=not server_args.disable_log_stats)
         return server
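For illustration, not part of the diff: a sketch of how a web handler is expected to drive the reworked AsyncLLMServer, with the new abort() path releasing the sequence group if the client goes away. The handler name and prompt handling are made up; `server` is assumed to be an AsyncLLMServer built via from_server_args, and SamplingParams is assumed to be default-constructible.

import asyncio

from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid

async def handle_request(prompt: str) -> None:
    request_id = random_uuid()
    results_generator = server.generate(prompt, SamplingParams(), request_id)
    try:
        async for request_output in results_generator:
            ...  # stream request_output back to the client
    except asyncio.CancelledError:
        # The client disconnected mid-stream: free the request's resources.
        await server.abort(request_id)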
@@ -1,11 +1,6 @@
 import time
 from typing import Any, List, Optional
 
-try:
-    import ray
-except ImportError:
-    ray = None
-
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
 from cacheflow.core.scheduler import Scheduler
@@ -13,7 +8,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import initialize_cluster
+from cacheflow.server.ray_utils import ray, initialize_cluster
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -62,7 +57,7 @@ class LLMServer:
         assert len(stage_devices) == 1, "Only support one stage for now."
         for rank, node_resource, _ in stage_devices[0]:
             worker_cls = Worker
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 worker_cls = ray.remote(
                     num_cpus=0,
                     num_gpus=1,
@@ -152,6 +147,9 @@ class LLMServer:
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)
 
+    def abort_request(self, request_id: str) -> None:
+        self.scheduler.abort_seq_group(request_id)
+
     def get_num_unfinished_requests(self) -> int:
         return self.scheduler.get_num_unfinished_seq_groups()
 
@@ -243,13 +241,13 @@ class LLMServer:
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
-            if self.parallel_config.use_ray:
+            if self.parallel_config.worker_use_ray:
                 executor = executor.remote
 
             output = executor(*args, **kwargs)
             all_outputs.append(output)
 
-        if self.parallel_config.use_ray:
+        if self.parallel_config.worker_use_ray:
             all_outputs = ray.get(all_outputs)
 
         if get_all_outputs:
@@ -13,9 +13,18 @@ DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), devi
 
 def initialize_cluster(
     parallel_config: ParallelConfig,
+    server_use_ray: bool = False,
     address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
-    if not parallel_config.use_ray:
+    if parallel_config.worker_use_ray or server_use_ray:
+        if ray is None:
+            raise ImportError(
+                "Ray is not installed. Please install Ray to use distributed "
+                "serving.")
+        # Connect to a ray cluster.
+        ray.init(address=address)
+
+    if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
         port = random.randint(10000, 20000)
         # We need to setup the distributed init method to make sure
@@ -24,13 +33,6 @@ def initialize_cluster(
         all_stage_devices = [[(0, None, 0)]]
         return distributed_init_method, all_stage_devices
 
-    if ray is None:
-        raise ImportError(
-            "Ray is not installed. Please install Ray to use distributed "
-            "serving.")
-    # Connect to a ray cluster.
-    ray.init(address=address)
-
     # Assume we have a uniform cluster that each node has the same number of
     # GPUs for now.
     valid_node_resources = []