Support Batch Completion in Server (#2529)

Simon Mo 2024-01-24 17:11:07 -08:00 committed by GitHub
parent 223c19224b
commit 3a7dd7e367
2 changed files with 214 additions and 104 deletions

File 1 of 2

@@ -1,5 +1,6 @@
import time
import os
import subprocess
import time
import sys
import pytest
@@ -17,8 +18,11 @@ pytestmark = pytest.mark.asyncio
class ServerRunner:
def __init__(self, args):
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
self.proc = subprocess.Popen(
["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
@@ -58,7 +62,8 @@ def server():
"--dtype",
"bfloat16", # use half precision for speed and memory savings in CI environment
"--max-model-len",
"8192"
"8192",
"--enforce-eager",
])
ray.get(server_runner.ready.remote())
yield server_runner
@@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
assert "".join(chunks) == output
async def test_batch_completions(server, client: openai.AsyncOpenAI):
# test simple list
batch = await client.completions.create(
model=MODEL_NAME,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
)
assert len(batch.choices) == 2
assert batch.choices[0].text == batch.choices[1].text
# test n = 2
batch = await client.completions.create(
model=MODEL_NAME,
prompt=["Hello, my name is", "Hello, my name is"],
n=2,
max_tokens=5,
temperature=0.0,
extra_body=dict(
# NOTE: this has to be true for n > 1 in vLLM, but is not necessary for the official client.
use_beam_search=True),
)
assert len(batch.choices) == 4
assert batch.choices[0].text != batch.choices[
1].text, "beam search should be different"
assert batch.choices[0].text == batch.choices[
2].text, "two copies of the same prompt should be the same"
assert batch.choices[1].text == batch.choices[
3].text, "two copies of the same prompt should be the same"
# test streaming
batch = await client.completions.create(
model=MODEL_NAME,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
stream=True,
)
texts = [""] * 2
async for chunk in batch:
assert len(chunk.choices) == 1
choice = chunk.choices[0]
texts[choice.index] += choice.text
assert texts[0] == texts[1]
if __name__ == "__main__":
pytest.main([__file__])
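
For reference, the same batched behavior can be exercised without the openai client. Below is a minimal sketch against the raw HTTP endpoint; it assumes the server started by the fixture above listens on localhost:8000 (the default for vllm.entrypoints.openai.api_server), that the requests library is available, and the model name and prompts are placeholders:

import requests

# A batch request is an ordinary /v1/completions call whose "prompt" field
# is a list; every prompt contributes its own choice(s) to the response.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "<served model name>",
        "prompt": ["Hello, my name is", "The capital of France is"],
        "max_tokens": 5,
        "temperature": 0.0,
    },
)
choices = resp.json()["choices"]
assert len(choices) == 2  # one choice per prompt when n defaults to 1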

File 2 of 2

@@ -1,6 +1,7 @@
import asyncio
import time
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator
from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -18,48 +19,68 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
logger = init_logger(__name__)
TypeTokenIDs = list[int]
TypeTopLogProbs = List[Optional[dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
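
For orientation, here is a stub matching the TypeCreateLogProbsFn shape; it is purely illustrative (the callable actually passed in below is self._create_logprobs, whose body is outside this diff):

def _example_create_logprobs(token_ids: TypeTokenIDs,
                             top_logprobs: TypeTopLogProbs,
                             num_output_top_logprobs: Optional[int],
                             initial_text_offset: int = 0) -> LogProbs:
    # A real implementation would resolve token ids to strings and fill the
    # OpenAI-style logprobs fields; this stub only illustrates the call
    # signature expected by the alias above.
    return LogProbs()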
async def completion_stream_generator(
request: CompletionRequest,
result_generator: AsyncIterator[RequestOutput],
echo_without_generation, create_logprobs_fn, request_id, created_time,
model_name) -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
has_echoed = [False] * request.n
request: CompletionRequest,
raw_request: Request,
on_abort,
result_generator: AsyncIterator[tuple[int, RequestOutput]],
create_logprobs_fn: TypeCreateLogProbsFn,
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
) -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
async for prompt_idx, res in result_generator:
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await on_abort(f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
async for res in result_generator:
# TODO: handle client disconnect for streaming
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
token_ids = output.token_ids[previous_num_tokens[i]:]
if request.logprobs is not None:
top_logprobs = output.logprobs[previous_num_tokens[i]:]
else:
top_logprobs = None
offsets = len(previous_texts[i])
if request.echo and not has_echoed[i]:
if not echo_without_generation:
delta_text = res.prompt + delta_text
token_ids = res.prompt_token_ids + token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs + top_logprobs
else: # only just return the prompt
delta_text = res.prompt
token_ids = res.prompt_token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs
i = output.index + prompt_idx * request.n
# TODO(simon): optimize performance by avoiding the O(n^2) full-text sending.
if request.echo and request.max_tokens == 0:
# only return the prompt
delta_text = res.prompt
delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
elif request.echo and request.max_tokens > 0 and not has_echoed[i]:
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = res.prompt_token_ids + output.token_ids
top_logprobs = res.prompt_logprobs + (output.logprobs or [])
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs is not None:
assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
logprobs = create_logprobs_fn(
token_ids=token_ids,
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=offsets,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
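
Since choices from every prompt now share one set of bookkeeping lists, the generator flattens (prompt index, sample index) into a single slot, as in the i = output.index + prompt_idx * request.n line above. A small self-contained sketch of that mapping (names are illustrative):

def flat_choice_index(prompt_idx: int, output_index: int, n: int) -> int:
    # Slot used for previous_texts / previous_num_tokens / has_echoed:
    # prompts are laid out contiguously, n samples per prompt.
    return output_index + prompt_idx * n

# With num_prompts=2 and n=2 the slots are 0,1 (prompt 0) and 2,3 (prompt 1),
# matching the ordering asserted in test_batch_completions.
assert flat_choice_index(prompt_idx=1, output_index=0, n=2) == 2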
@@ -77,7 +98,7 @@ async def completion_stream_generator(
]).model_dump_json(exclude_unset=True)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
if output.finish_reason is not None: # return final usage
logprobs = LogProbs() if request.logprobs is not None else None
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
@@ -129,51 +150,58 @@ def parse_prompt_format(prompt) -> tuple[bool, list]:
return prompt_is_tokens, prompts
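
parse_prompt_format itself is unchanged apart from its return line shown above; for context, here is a sketch of the normalization it implies, assumed from the OpenAI completions API (a string, a list of strings, a list of token ids, or a list of token-id lists) rather than copied verbatim:

def parse_prompt_format_sketch(prompt) -> tuple[bool, list]:
    # Normalize the four OpenAI-style prompt shapes into (prompt_is_tokens, prompts).
    if isinstance(prompt, str):
        return False, [prompt]        # single string
    if isinstance(prompt, list) and prompt:
        if isinstance(prompt[0], str):
            return False, prompt      # list of strings
        if isinstance(prompt[0], int):
            return True, [prompt]     # one tokenized prompt
        if isinstance(prompt[0], list):
            return True, prompt       # list of tokenized prompts
    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")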
def request_output_to_completion_response(final_res: RequestOutput, request,
echo_without_generation,
create_logprobs_fn, request_id,
created_time,
model_name) -> CompletionResponse:
assert final_res is not None
def request_output_to_completion_response(
final_res_batch: list[RequestOutput],
request: CompletionRequest,
create_logprobs_fn: TypeCreateLogProbsFn,
request_id: str,
created_time: int,
model_name: str,
) -> CompletionResponse:
choices = []
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.logprobs is not None:
if not echo_without_generation:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.echo:
token_ids = prompt_token_ids + token_ids
top_logprobs = prompt_logprobs + top_logprobs
else:
num_prompt_tokens = 0
num_generated_tokens = 0
for final_res in final_res_batch:
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
if not echo_without_generation:
output_text = output.text
if request.echo:
output_text = prompt_text + output_text
else:
output_text = prompt_text
choice_data = CompletionResponseChoice(
index=output.index,
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
top_logprobs = prompt_logprobs + output.logprobs
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
top_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=len(choices),
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
@@ -189,6 +217,36 @@ def request_output_to_completion_response(final_res: RequestOutput, request,
)
def merge_async_iterators(*iterators):
"""Merge multiple asynchronous iterators into a single iterator.
This method handles the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yielded the item.
"""
queue = asyncio.Queue()
finished = [False] * len(iterators)
async def producer(i, iterator):
async for item in iterator:
await queue.put((i, item))
finished[i] = True
_tasks = [
asyncio.create_task(producer(i, iterator))
for i, iterator in enumerate(iterators)
]
async def consumer():
while not all(finished) or not queue.empty():
item = await queue.get()
yield item
await asyncio.gather(*_tasks)
return consumer()
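
A self-contained sketch of how merge_async_iterators behaves; the two toy producers are hypothetical, only the helper itself comes from this file:

import asyncio

async def ticker(tag: str, delay: float):
    for i in range(3):
        await asyncio.sleep(delay)
        yield f"{tag}-{i}"

async def demo():
    merged = merge_async_iterators(ticker("fast", 0.01), ticker("slow", 0.05))
    async for idx, item in merged:
        # idx identifies which input iterator produced the item, which is how
        # the completion path keys per-prompt state and sub-request ids.
        print(idx, item)

# asyncio.run(demo()) interleaves fast-* and slow-* items, tagged 0 and 1.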
class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, served_model: str):
@@ -210,9 +268,6 @@ class OpenAIServingCompletion(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
# OpenAI API supports echoing the prompt when max_tokens is 0.
echo_without_generation = request.echo and request.max_tokens == 0
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
@@ -226,30 +281,30 @@ class OpenAIServingCompletion(OpenAIServing):
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
generators = []
try:
sampling_params = request.to_sampling_params()
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
if len(prompts) > 1:
raise ValueError(
"Batching in completion API is not supported.")
prompt = prompts[0]
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
input_ids = self._validate_prompt_and_tokenize(
request, prompt=prompt)
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
input_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
result_generator = self.engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=input_ids)
generators.append(
self.engine.generate(None,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=input_ids))
except ValueError as e:
return self.create_error_response(str(e))
result_generator: AsyncIterator[tuple[
int, RequestOutput]] = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when using beam search.
stream = (request.stream
@@ -258,23 +313,27 @@ class OpenAIServingCompletion(OpenAIServing):
# Streaming response
if stream:
return completion_stream_generator(request, result_generator,
echo_without_generation,
return completion_stream_generator(request,
raw_request,
self.engine.abort,
result_generator,
self._create_logprobs,
request_id, created_time,
model_name)
request_id,
created_time,
model_name,
num_prompts=len(prompts))
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
await self.engine.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res = res
final_res_batch[i] = res
response = request_output_to_completion_response(
final_res, request, echo_without_generation, self._create_logprobs,
request_id, created_time, model_name)
final_res_batch, request, self._create_logprobs, request_id,
created_time, model_name)
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
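
A sketch of what that single-event stream can look like; the helper name and body are assumed for illustration (they are outside this hunk) and reuse the SSE framing already used by completion_stream_generator above:

# Assumed helper, for illustration only: wrap the fully materialized
# `response` in a one-event SSE stream when the client asked for streaming
# but the request is not streamable (n != best_of, or beam search).
async def fake_stream_generator() -> AsyncGenerator[str, None]:
    yield f"data: {response.model_dump_json()}\n\n"
    yield "data: [DONE]\n\n"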