Fix various issues of async servers (#135)

Zhuohan Li 2023-06-05 23:44:50 +08:00 committed by GitHub
parent 8274ca23ac
commit 1a956e136b
11 changed files with 289 additions and 121 deletions

View File

@ -0,0 +1,58 @@
import argparse
import json
import threading
import time
import requests
def main(args: argparse.Namespace):
prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
for i in range(args.n_threads)]
headers = {"User-Agent": "CacheFlow Benchmark Client"}
ploads = [{
"prompt": p,
"max_tokens": args.max_tokens,
"temperature": 0.0,
"ignore_eos": True,
} for p in prompts]
def send_request(results, i):
response = requests.post(args.api_url, headers=headers,
json=ploads[i], stream=True)
results[i] = response
# use args.n_threads to prompt the backend
tik = time.time()
threads = []
results = [None] * args.n_threads
for i in range(args.n_threads):
t = threading.Thread(target=send_request, args=(results, i))
t.start()
threads.append(t)
for t in threads:
t.join()
print(f"Time (POST): {time.time() - tik} s")
n_words = 0
for i, response in enumerate(results):
k = list(response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"))
response_new_words = json.loads(k[-2].decode("utf-8"))["text"][0]
n_words += len(response_new_words.split(" ")) - len(prompts[i].split(" "))
time_seconds = time.time() - tik
print(f"Time (total): {time_seconds:.3f}s to finish, n_threads: {args.n_threads}, "
f"throughput: {n_words / time_seconds} words/s.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
parser.add_argument("--max-tokens", type=int, default=128)
parser.add_argument("--n-threads", type=int, default=128)
args = parser.parse_args()
main(args)
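
A note on the parsing above: the simple frontend further below streams NUL-terminated JSON payloads, and each later payload carries the full text generated so far, so splitting on b"\0" leaves a trailing empty element and the last complete payload sits at index -2 (hence k[-2]). A minimal offline sketch of that parsing against synthetic data (no server involved):

import json

# Simulated wire format: every update is a JSON object followed by a NUL byte,
# and later updates repeat the full text generated so far.
stream = b"".join(
    json.dumps({"text": [text]}).encode("utf-8") + b"\0"
    for text in ["Tell me a story", "Tell me a story once upon a time"]
)

chunks = stream.split(b"\0")                     # trailing NUL -> last element is b""
final = json.loads(chunks[-2].decode("utf-8"))   # last complete payload
print(final["text"][0])                          # full generated text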

View File

@ -116,15 +116,15 @@ class ParallelConfig:
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
use_ray: bool,
worker_use_ray: bool,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.use_ray = use_ray
self.worker_use_ray = worker_use_ray
self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.use_ray = True
self.worker_use_ray = True
self._verify_args()
def _verify_args(self) -> None:
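
The rename above keeps worker-level Ray usage (worker_use_ray) distinct from the server-level flag (server_use_ray) introduced later in this commit, and still forces Ray workers on for any multi-GPU configuration. A simplified standalone reconstruction of just that rule (not the real cacheflow.config class):

class ParallelConfigSketch:
    """Stand-in limited to the lines shown above."""

    def __init__(self, pipeline_parallel_size: int,
                 tensor_parallel_size: int, worker_use_ray: bool) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.world_size = pipeline_parallel_size * tensor_parallel_size
        if self.world_size > 1:
            # Any multi-GPU setup requires Ray workers, regardless of the flag.
            self.worker_use_ray = True


print(ParallelConfigSketch(1, 1, worker_use_ray=False).worker_use_ray)  # False
print(ParallelConfigSketch(1, 2, worker_use_ray=False).worker_use_ray)  # True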

View File

@ -148,7 +148,7 @@ class BlockSpaceManager:
# the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set()
for seq in seq_group.get_seqs():
if SequenceStatus.is_finished(seq.status):
if seq.is_finished():
continue
block_table = self.block_tables[seq.seq_id]
for block in block_table:
@ -169,7 +169,7 @@ class BlockSpaceManager:
# CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs():
if SequenceStatus.is_finished(seq.status):
if seq.is_finished():
continue
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
@ -200,7 +200,7 @@ class BlockSpaceManager:
# GPU block -> CPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs():
if SequenceStatus.is_finished(seq.status):
if seq.is_finished():
continue
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
@ -231,6 +231,9 @@ class BlockSpaceManager:
self.cpu_allocator.free(block)
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
# Already freed or hasn't been scheduled yet.
return
block_table = self.block_tables[seq.seq_id]
self._free_block_table(block_table)
del self.block_tables[seq.seq_id]
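
The early return added to free() makes it safe to call for sequences that were aborted before ever being scheduled (no block table yet) or that were already freed. A generic sketch of that idempotent-release pattern, using a hypothetical resource table rather than the real block allocator:

from typing import Dict, List

class ResourceTableSketch:
    """Hypothetical stand-in for the seq_id -> block-table mapping."""

    def __init__(self) -> None:
        self.tables: Dict[int, List[int]] = {}

    def allocate(self, seq_id: int) -> None:
        self.tables[seq_id] = [0, 1, 2]   # pretend physical block ids

    def free(self, seq_id: int) -> None:
        if seq_id not in self.tables:
            # Already freed, or the sequence was never scheduled.
            return
        del self.tables[seq_id]

table = ResourceTableSketch()
table.allocate(7)
table.free(7)
table.free(7)    # second free is a no-op instead of a KeyError
table.free(42)   # never allocated: also a no-op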

View File

@ -12,7 +12,7 @@ from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 10
_LOGGING_INTERVAL_SEC = 5
class PreemptionMode(enum.Enum):
@ -84,6 +84,18 @@ class Scheduler:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
def abort_seq_group(self, request_id: str) -> None:
for state_queue in [self.waiting, self.running, self.swapped]:
for seq_group in state_queue:
if seq_group.request_id == request_id:
# Remove the sequence group from the state queue.
state_queue.remove(seq_group)
for seq in seq_group.seqs:
if seq.is_finished():
continue
self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
return
def has_unfinished_seqs(self) -> bool:
return self.waiting or self.running or self.swapped
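
abort_seq_group() scans the waiting, running, and swapped queues for the matching request id, removes the group from whichever queue holds it, and frees only the sequences that have not already finished. A standalone sketch of that search-and-free flow over plain lists, with simplified stand-ins (SeqSketch and SeqGroupSketch are illustrative names, not the real classes):

from dataclasses import dataclass, field
from typing import List

@dataclass
class SeqSketch:
    seq_id: int
    finished: bool = False

    def is_finished(self) -> bool:
        return self.finished

@dataclass
class SeqGroupSketch:
    request_id: str
    seqs: List[SeqSketch] = field(default_factory=list)

def abort_seq_group(queues: List[List[SeqGroupSketch]], request_id: str) -> None:
    for state_queue in queues:                 # waiting, running, swapped
        for seq_group in state_queue:
            if seq_group.request_id == request_id:
                state_queue.remove(seq_group)  # drop the group from its queue
                for seq in seq_group.seqs:
                    if seq.is_finished():
                        continue
                    seq.finished = True        # free_seq(..., FINISHED_ABORTED) in the real code
                return                         # request ids are unique

waiting = [SeqGroupSketch("req-1", [SeqSketch(0)])]
running: List[SeqGroupSketch] = []
abort_seq_group([waiting, running, []], "req-1")
print(waiting)   # [] -> the group was removed and its sequence freed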

View File

@ -7,13 +7,14 @@ import time
from typing import AsyncGenerator, Dict, List, Optional
import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn
from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import (
UsageInfo,
)
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
served_model = None
@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
async def create_completion(raw_request: Request):
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")
error_check_ret = await check_model(request)
@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest):
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = server.generate(prompt, sampling_params,
request_id=request_id)
request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when using beam search.
@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest):
(request.best_of is None or request.n == request.best_of) and
not request.use_beam_search)
async def abort_request() -> None:
await server.abort(request_id)
def create_stream_response_json(index: int,
text: str,
logprobs: Optional[LogProbs] = None,
@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest):
# Streaming response
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")
media_type="text/event-stream",
background=background_tasks)
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await server.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
@ -276,7 +291,7 @@ if __name__ == "__main__":
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = ServerArgs.add_cli_args(parser)
parser = AsyncServerArgs.add_cli_args(parser)
args = parser.parse_args()
app.add_middleware(
@ -291,10 +306,11 @@ if __name__ == "__main__":
served_model = args.served_model_name or args.model
server_args = ServerArgs.from_cli_args(args)
server_args = AsyncServerArgs.from_cli_args(args)
server = AsyncLLMServer.from_server_args(server_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
uvicorn.run(app, host=args.host, port=args.port, log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
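
Both disconnect paths above share one idea: for streaming responses, the abort coroutine rides along as a background task that runs once the stream ends or the client drops; for the non-streaming path, the result loop polls Request.is_disconnected() and aborts explicitly. A minimal FastAPI sketch of the two paths with placeholder functions (dummy_results and abort_request stand in for server.generate and server.abort; they are not CacheFlow APIs):

import asyncio
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

app = FastAPI()

async def dummy_results(request_id: str):
    # Placeholder for server.generate(...): a few incremental outputs.
    for i in range(3):
        await asyncio.sleep(0.1)
        yield f"chunk {i} of {request_id}"

async def abort_request(request_id: str) -> None:
    # Placeholder for server.abort(request_id).
    print(f"aborted {request_id}")

@app.post("/stream")
async def stream(raw_request: Request):
    request_id = "req-stream"
    async def gen():
        async for text in dummy_results(request_id):
            yield (text + "\n").encode("utf-8")
    background_tasks = BackgroundTasks()
    # Runs after the stream finishes or the client disconnects.
    background_tasks.add_task(abort_request, request_id)
    return StreamingResponse(gen(), media_type="text/plain",
                             background=background_tasks)

@app.post("/complete")
async def complete(raw_request: Request):
    request_id = "req-complete"
    final = None
    async for text in dummy_results(request_id):
        if await raw_request.is_disconnected():
            # Stop generating as soon as the client goes away.
            await abort_request(request_id)
            return JSONResponse({"error": "Client disconnected"}, status_code=400)
        final = text
    return JSONResponse({"text": final})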

View File

@ -2,15 +2,16 @@ import argparse
import json
from typing import AsyncGenerator
from fastapi import FastAPI, Request
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.ray_utils import initialize_cluster
from cacheflow.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()
@ -20,7 +21,8 @@ async def generate_stream(request: Request) -> StreamingResponse:
request_dict = await request.json()
prompt = request_dict.pop("prompt")
sampling_params = SamplingParams(**request_dict)
results_generator = server.generate(prompt, sampling_params)
request_id = random_uuid()
results_generator = server.generate(prompt, sampling_params, request_id)
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
@ -35,17 +37,24 @@ async def generate_stream(request: Request) -> StreamingResponse:
}
yield (json.dumps(ret) + "\0").encode("utf-8")
return StreamingResponse(stream_results())
async def abort_request() -> None:
await server.abort(request_id)
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser = ServerArgs.add_cli_args(parser)
parser = AsyncServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
server_args = AsyncServerArgs.from_cli_args(args)
server = AsyncLLMServer.from_server_args(server_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

View File

@ -12,12 +12,14 @@ class SequenceStatus(enum.Enum):
SWAPPED = enum.auto()
FINISHED_STOPPED = enum.auto()
FINISHED_LENGTH_CAPPED = enum.auto()
FINISHED_ABORTED = enum.auto()
@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
return status in [
SequenceStatus.FINISHED_STOPPED,
SequenceStatus.FINISHED_LENGTH_CAPPED,
SequenceStatus.FINISHED_ABORTED,
]
@staticmethod
@ -26,10 +28,13 @@ class SequenceStatus(enum.Enum):
finish_reason = "stop"
elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
finish_reason = "length"
elif status == SequenceStatus.FINISHED_ABORTED:
finish_reason = "abort"
else:
finish_reason = None
return finish_reason
class SequenceData:
def __init__(
@ -137,6 +142,9 @@ class Sequence:
def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob
def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status)
def fork(self, child_seq: 'Sequence') -> None:
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
@ -182,7 +190,7 @@ class SequenceGroup:
raise ValueError(f'Sequence {seq_id} not found.')
def is_finished(self) -> bool:
return all(SequenceStatus.is_finished(seq.status) for seq in self.seqs)
return all(seq.is_finished() for seq in self.seqs)
def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, "

View File

@ -15,7 +15,7 @@ class ServerArgs:
use_dummy_weights: bool = False
dtype: str = "default"
seed: int = 0
use_ray: bool = False
worker_use_ray: bool = False
pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1
block_size: int = 16
@ -32,7 +32,63 @@ class ServerArgs:
def add_cli_args(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
return _add_server_arguments(parser)
"""Shared CLI arguments for CacheFlow servers."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument('--download-dir', type=str,
default=ServerArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
parser.add_argument('--use-np-weights', action='store_true',
help='save a numpy copy of model weights for '
'faster loading. This can increase the disk '
'usage by up to 2x.')
parser.add_argument('--use-dummy-weights', action='store_true',
help='use dummy values for model weights')
# TODO(woosuk): Support FP32.
parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
choices=['default', 'half', 'bfloat16'],
help='data type for model weights and activations. '
'The "default" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
# Parallel arguments
parser.add_argument('--worker-use-ray', action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
default=ServerArgs.pipeline_parallel_size,
help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int,
default=ServerArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int,
default=ServerArgs.block_size,
choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=ServerArgs.seed,
help='random seed')
parser.add_argument('--swap-space', type=int,
default=ServerArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float,
default=ServerArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for '
'the model executor')
parser.add_argument('--max-num-batched-tokens', type=int,
default=ServerArgs.max_num_batched_tokens,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--max-num-seqs', type=int,
default=ServerArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true',
help='disable logging statistics')
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
@ -53,65 +109,22 @@ class ServerArgs:
self.swap_space)
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.use_ray)
self.worker_use_ray)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs)
return model_config, cache_config, parallel_config, scheduler_config
def _add_server_arguments(
parser: argparse.ArgumentParser,
)-> argparse.ArgumentParser:
"""Shared CLI arguments for CacheFlow servers."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument('--download-dir', type=str,
default=ServerArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument('--use-np-weights', action='store_true',
help='save a numpy copy of model weights for faster '
'loading. This can increase the disk usage by up '
'to 2x.')
parser.add_argument('--use-dummy-weights', action='store_true',
help='use dummy values for model weights')
# TODO(woosuk): Support FP32.
parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'))
# Parallel arguments
parser.add_argument('--use-ray', action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
default=ServerArgs.pipeline_parallel_size,
help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int,
default=ServerArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=ServerArgs.seed,
help='random seed')
parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float,
default=ServerArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for the '
'model executor')
parser.add_argument('--max-num-batched-tokens', type=int,
default=ServerArgs.max_num_batched_tokens,
help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-seqs', type=int,
default=ServerArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true',
help='disable logging statistics')
return parser
@dataclass
class AsyncServerArgs(ServerArgs):
server_use_ray: bool = False
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
parser = ServerArgs.add_cli_args(parser)
parser.add_argument('--server-use-ray', action='store_true',
help='use Ray to start the LLM server in a '
'process separate from the web server process.')
return parser
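
The refactor folds the old module-level _add_server_arguments() into ServerArgs.add_cli_args(), so AsyncServerArgs can reuse the same parser and append --server-use-ray on top. A generic sketch of that dataclass-plus-argparse layering with hypothetical names (BaseArgsSketch and AsyncArgsSketch, not the CacheFlow classes):

import argparse
from dataclasses import dataclass, fields

@dataclass
class BaseArgsSketch:
    model: str = "facebook/opt-125m"
    seed: int = 0

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Dataclass defaults double as argparse defaults, as in ServerArgs above.
        parser.add_argument("--model", type=str, default=BaseArgsSketch.model)
        parser.add_argument("--seed", type=int, default=BaseArgsSketch.seed)
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "BaseArgsSketch":
        # Pick out only the attributes this (sub)class declares.
        return cls(**{f.name: getattr(args, f.name) for f in fields(cls)})

@dataclass
class AsyncArgsSketch(BaseArgsSketch):
    server_use_ray: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser = BaseArgsSketch.add_cli_args(parser)
        parser.add_argument("--server-use-ray", action="store_true")
        return parser

parser = AsyncArgsSketch.add_cli_args(argparse.ArgumentParser())
args = parser.parse_args(["--model", "facebook/opt-1.3b", "--server-use-ray"])
print(AsyncArgsSketch.from_cli_args(args))
# AsyncArgsSketch(model='facebook/opt-1.3b', seed=0, server_use_ray=True)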

View File

@ -2,37 +2,52 @@ import asyncio
import time
from typing import Dict, Optional
import ray
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster
from cacheflow.utils import random_uuid
from cacheflow.server.ray_utils import ray, initialize_cluster
logger = init_logger(__name__)
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
class AsyncLLMServer:
def __init__(self, server_use_ray: bool, *args, **kwargs) -> None:
if server_use_ray:
remote_server_class = ray.remote(num_cpus=0)(LLMServer)
def __init__(self, worker_use_ray: bool, server_use_ray: bool,
*args, **kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.server_use_ray = server_use_ray
if not self.server_use_ray:
server_class = LLMServer
elif self.worker_use_ray:
server_class = ray.remote(num_cpus=0)(LLMServer).remote
else:
remote_server_class = ray.remote(num_gpus=1)(LLMServer)
self.server = remote_server_class.remote(*args, **kwargs)
server_class = ray.remote(num_gpus=1)(LLMServer).remote
self.server = server_class(*args, **kwargs)
# Request id -> request output.
self.request_outputs: Dict[str, RequestOutput] = {}
# Request id -> event to notify that there is new output.
self.request_events: Dict[str, asyncio.Event] = {}
self.is_server_running = False
self.kicking_request_id: Optional[str] = None
async def server_step(self):
async def server_step(self, kicking_request_id: Optional[str] = None):
self.is_server_running = True
request_outputs = await self.server.step.remote()
self.kicking_request_id = kicking_request_id
if self.server_use_ray:
request_outputs = await self.server.step.remote()
else:
# Yield to the event loop to allow other coroutines to run
# while is_server_running is True. This lets the server add new
# requests into the queue.
await asyncio.sleep(0)
request_outputs = self.server.step()
self.is_server_running = False
self.kicking_request_id = None
# Notify the waiting coroutines that there are new outputs ready.
for request_output in request_outputs:
request_id = request_output.request_id
@ -40,20 +55,26 @@ class AsyncLLMServer:
self.request_events[request_id].set()
async def generate(self, prompt: str, sampling_params: SamplingParams,
request_id: Optional[str] = None) -> RequestOutput:
request_id: str) -> RequestOutput:
# Preprocess the request.
arrival_time = time.time()
# Create an event to notify us that there is new output from the
# cacheflow server.
if request_id is None:
request_id = random_uuid()
request_event = asyncio.Event()
self.request_events[request_id] = request_event
logger.info(f"Received request {request_id}: "
f"prompt: {prompt!r}, "
f"sampling params: {sampling_params}.")
# Add the request into the cacheflow server's waiting queue.
await self.server.add_request.remote(
request_id, prompt, sampling_params, arrival_time=arrival_time)
if self.server_use_ray:
await self.server.add_request.remote(
request_id, prompt, sampling_params, arrival_time=arrival_time)
else:
self.server.add_request(
request_id, prompt, sampling_params, arrival_time=arrival_time)
# The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
@ -61,7 +82,7 @@ class AsyncLLMServer:
while True:
# Kick the server if the server is not running.
if not self.is_server_running:
await self.server_step()
await self.server_step(request_id)
# Wait for new output. The group_event will be set in server_step
# when there is new output available for the sequence group.
@ -80,6 +101,8 @@ class AsyncLLMServer:
# Once finished, release the resources of the sequence group.
if request_output.finished():
logger.info(f"Finished request {request_id}.")
del self.request_outputs[request_id]
del self.request_events[request_id]
# Kick the server if the server is not running. This is to
@ -89,15 +112,41 @@ class AsyncLLMServer:
await self.server_step()
break
async def abort(self, request_id: str) -> None:
if request_id not in self.request_events:
# The request has already finished or been aborted.
return
logger.info(f"Aborted request {request_id}.")
if self.server_use_ray:
await self.server.abort_request.remote(request_id)
else:
self.server.abort_request(request_id)
if request_id in self.request_events:
del self.request_events[request_id]
if request_id in self.request_outputs:
del self.request_outputs[request_id]
# To prevent deadlock when a request is aborted while the server is
# running.
if self.kicking_request_id == request_id:
self.is_server_running = False
self.kicking_request_id = None
@classmethod
def from_server_args(cls, server_args: ServerArgs) -> "AsyncLLMServer":
def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
distributed_init_method, devices = initialize_cluster(
parallel_config, server_args.server_use_ray)
# Create the LLM server.
server = cls(server_args.use_ray, *server_configs,
server = cls(server_args.worker_use_ray,
server_args.server_use_ray,
*server_configs,
distributed_init_method, devices,
log_stats=not server_args.disable_log_stats)
return server
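
The async wrapper has no background loop: each generate() call "kicks" server_step() whenever no step is in flight, then waits on a per-request asyncio.Event that the step sets when it produces output; abort() drops the request's bookkeeping and, if the aborted request was the one that kicked the current step, resets the running flag so the loop cannot deadlock. A condensed standalone sketch of the kick-and-wait flow with a fake synchronous engine (FakeEngine and AsyncWrapperSketch are illustrative, not the real classes; the abort path is omitted for brevity):

import asyncio
from typing import Dict, List, Optional, Tuple

class FakeEngine:
    """Stands in for LLMServer: each request 'finishes' after three step() calls."""

    def __init__(self) -> None:
        self.pending: Dict[str, int] = {}

    def add_request(self, request_id: str) -> None:
        self.pending[request_id] = 3

    def step(self) -> List[Tuple[str, bool]]:
        outputs = []
        for rid in list(self.pending):
            self.pending[rid] -= 1
            finished = self.pending[rid] == 0
            if finished:
                del self.pending[rid]
            outputs.append((rid, finished))
        return outputs

class AsyncWrapperSketch:
    def __init__(self) -> None:
        self.engine = FakeEngine()
        self.request_events: Dict[str, asyncio.Event] = {}
        self.request_outputs: Dict[str, bool] = {}
        self.is_server_running = False
        self.kicking_request_id: Optional[str] = None

    async def server_step(self, kicking_request_id: Optional[str] = None) -> None:
        self.is_server_running = True
        self.kicking_request_id = kicking_request_id
        await asyncio.sleep(0)            # yield so other coroutines can add requests
        outputs = self.engine.step()
        self.is_server_running = False
        self.kicking_request_id = None
        for rid, finished in outputs:     # wake every waiter that got new output
            if rid in self.request_events:
                self.request_outputs[rid] = finished
                self.request_events[rid].set()

    async def generate(self, request_id: str) -> None:
        event = asyncio.Event()
        self.request_events[request_id] = event
        self.engine.add_request(request_id)
        while True:
            if not self.is_server_running:
                await self.server_step(request_id)   # kick the engine if it is idle
            await event.wait()
            event.clear()
            if self.request_outputs.pop(request_id, False):
                del self.request_events[request_id]
                print(f"finished {request_id}")
                break

async def main() -> None:
    wrapper = AsyncWrapperSketch()
    await asyncio.gather(wrapper.generate("req-0"), wrapper.generate("req-1"))

asyncio.run(main())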

View File

@ -1,11 +1,6 @@
import time
from typing import Any, List, Optional
try:
import ray
except ImportError:
ray = None
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.core.scheduler import Scheduler
@ -13,7 +8,7 @@ from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import initialize_cluster
from cacheflow.server.ray_utils import ray, initialize_cluster
from cacheflow.server.tokenizer_utils import (get_tokenizer,
detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@ -62,7 +57,7 @@ class LLMServer:
assert len(stage_devices) == 1, "Only support one stage for now."
for rank, node_resource, _ in stage_devices[0]:
worker_cls = Worker
if self.parallel_config.use_ray:
if self.parallel_config.worker_use_ray:
worker_cls = ray.remote(
num_cpus=0,
num_gpus=1,
@ -152,6 +147,9 @@ class LLMServer:
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: str) -> None:
self.scheduler.abort_seq_group(request_id)
def get_num_unfinished_requests(self) -> int:
return self.scheduler.get_num_unfinished_seq_groups()
@ -243,13 +241,13 @@ class LLMServer:
all_outputs = []
for worker in self.workers:
executor = getattr(worker, method)
if self.parallel_config.use_ray:
if self.parallel_config.worker_use_ray:
executor = executor.remote
output = executor(*args, **kwargs)
all_outputs.append(output)
if self.parallel_config.use_ray:
if self.parallel_config.worker_use_ray:
all_outputs = ray.get(all_outputs)
if get_all_outputs:
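
_run_workers() (partially shown above) issues the same method on every worker and only goes through .remote handles and ray.get() when worker_use_ray is set. A tiny sketch of that conditional dispatch with a local stand-in for a worker, so it runs without Ray (the Ray branch is shown but only taken when use_ray=True):

from typing import Any, List

class LocalWorkerSketch:
    """Stand-in for Worker; a Ray actor would expose the same method via .remote."""

    def execute(self, x: int) -> int:
        return x * 2

def run_workers(workers: List[Any], method: str, use_ray: bool, *args) -> List[Any]:
    all_outputs = []
    for worker in workers:
        executor = getattr(worker, method)
        if use_ray:
            executor = executor.remote        # Ray actor method handle
        all_outputs.append(executor(*args))
    if use_ray:
        import ray                            # resolve ObjectRefs into values
        all_outputs = ray.get(all_outputs)
    return all_outputs

print(run_workers([LocalWorkerSketch(), LocalWorkerSketch()], "execute", False, 21))
# [42, 42]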

View File

@ -13,9 +13,18 @@ DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
def initialize_cluster(
parallel_config: ParallelConfig,
server_use_ray: bool = False,
address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]:
if not parallel_config.use_ray:
if parallel_config.worker_use_ray or server_use_ray:
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=address)
if not parallel_config.worker_use_ray:
# Initialize cluster locally.
port = random.randint(10000, 20000)
# We need to setup the distributed init method to make sure
@ -24,13 +33,6 @@ def initialize_cluster(
all_stage_devices = [[(0, None, 0)]]
return distributed_init_method, all_stage_devices
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=address)
# Assume we have a uniform cluster that each node has the same number of
# GPUs for now.
valid_node_resources = []