refactor complemention api for readability (#2499)

This commit is contained in:
Simon Mo 2024-01-18 16:45:14 -08:00 committed by GitHub
parent d2a68364c4
commit dd7e8f5f64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 286 additions and 255 deletions

View File

@ -88,6 +88,16 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
messages = [{

View File

@ -6,6 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
class ErrorResponse(BaseModel):
@ -78,6 +79,26 @@ class ChatCompletionRequest(BaseModel):
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
def to_sampling_params(self) -> SamplingParams:
return SamplingParams(
n=self.n,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
min_p=self.min_p,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
use_beam_search=self.use_beam_search,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
)
class CompletionRequest(BaseModel):
model: str
@ -107,6 +128,30 @@ class CompletionRequest(BaseModel):
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0
return SamplingParams(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens),
)
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)

View File

@ -11,7 +11,6 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.entrypoints.openai.serving_engine import OpenAIServing
logger = init_logger(__name__)
@ -60,32 +59,11 @@ class OpenAIServingChat(OpenAIServing):
f"Error in applying chat template from request: {str(e)}")
return self.create_error_response(str(e))
token_ids, error_check_ret = await self._check_length(request,
prompt=prompt)
if error_check_ret is not None:
return error_check_ret
request_id = f"cmpl-{random_uuid()}"
try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
top_p=request.top_p,
min_p=request.min_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
except ValueError as e:
return self.create_error_response(str(e))

View File

@ -1,20 +1,194 @@
import time
from fastapi import Request
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, AsyncIterator
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from .protocol import (CompletionRequest, CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse, LogProbs, UsageInfo)
from .protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs,
UsageInfo,
)
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.entrypoints.openai.serving_engine import OpenAIServing
logger = init_logger(__name__)
async def completion_stream_generator(
request: CompletionRequest,
result_generator: AsyncIterator[RequestOutput],
echo_without_generation, create_logprobs_fn, request_id, created_time,
model_name) -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
has_echoed = [False] * request.n
async for res in result_generator:
# TODO: handle client disconnect for streaming
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
token_ids = output.token_ids[previous_num_tokens[i]:]
if request.logprobs is not None:
top_logprobs = output.logprobs[previous_num_tokens[i]:]
else:
top_logprobs = None
offsets = len(previous_texts[i])
if request.echo and not has_echoed[i]:
if not echo_without_generation:
delta_text = res.prompt + delta_text
token_ids = res.prompt_token_ids + token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs + top_logprobs
else: # only just return the prompt
delta_text = res.prompt
token_ids = res.prompt_token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
if request.logprobs is not None:
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=offsets,
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
response_json = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
)
]).json(exclude_unset=True, ensure_ascii=False)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
)
],
usage=final_usage,
).json(exclude_unset=True, ensure_ascii=False)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
def parse_prompt_format(prompt) -> tuple[bool, list]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens = False
prompts = [prompt] # case 1: a string
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
elif isinstance(prompt[0], str):
prompt_is_tokens = False
prompts = prompt # case 2: array of strings
elif isinstance(prompt[0], int):
prompt_is_tokens = True
prompts = [prompt] # case 3: array of tokens
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError(
"prompt must be a string, array of strings, array of tokens, or array of token arrays"
)
return prompt_is_tokens, prompts
def request_output_to_completion_response(final_res: RequestOutput, request,
echo_without_generation,
create_logprobs_fn, request_id,
created_time,
model_name) -> CompletionResponse:
assert final_res is not None
choices = []
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.logprobs is not None:
if not echo_without_generation:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.echo:
token_ids = prompt_token_ids + token_ids
top_logprobs = prompt_logprobs + top_logprobs
else:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
if not echo_without_generation:
output_text = output.text
if request.echo:
output_text = prompt_text + output_text
else:
output_text = prompt_text
choice_data = CompletionResponseChoice(
index=output.index,
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, served_model: str):
@ -32,7 +206,6 @@ class OpenAIServingCompletion(OpenAIServing):
suffix)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
@ -40,83 +213,42 @@ class OpenAIServingCompletion(OpenAIServing):
# OpenAI API supports echoing the prompt when max_tokens is 0.
echo_without_generation = request.echo and request.max_tokens == 0
# Return error for unsupported features.
if request.suffix is not None:
# The language models we currently support do not support suffix.
return self.create_error_response(
"suffix is not currently supported")
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return self.create_error_response(
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return self.create_error_response(
"please provide at least one prompt")
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
prompt = request.prompt
elif isinstance(first_element, (str, list)):
# TODO: handles multiple prompt case in list[list[int]]
if len(request.prompt) > 1:
return self.create_error_response(
"multiple prompts in a batch is not currently supported"
)
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
prompt = request.prompt
if use_token_ids:
_, error_check_ret = await self._check_length(request,
prompt_ids=prompt)
else:
token_ids, error_check_ret = await self._check_length(
request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
created_time = int(time.monotonic())
try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
min_p=request.min_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens
if not echo_without_generation else 1,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
prompt_logprobs=request.logprobs if request.echo else None,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
except ValueError as e:
return self.create_error_response(str(e))
if use_token_ids:
# Schedule the request and get the result generator.
try:
sampling_params = request.to_sampling_params()
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
if len(prompts) > 1:
raise ValueError(
"Batching in completion API is not supported.")
prompt = prompts[0]
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
input_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
result_generator = self.engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=prompt)
else:
result_generator = self.engine.generate(prompt, sampling_params,
request_id, token_ids)
prompt_token_ids=input_ids)
except ValueError as e:
return self.create_error_response(str(e))
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
@ -124,101 +256,13 @@ class OpenAIServingCompletion(OpenAIServing):
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
def create_stream_response_json(
index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
usage: Optional[UsageInfo] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
if usage is not None:
response.usage = usage
response_json = response.json(exclude_unset=True,
ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
has_echoed = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
token_ids = output.token_ids[previous_num_tokens[i]:]
if request.logprobs is not None:
top_logprobs = output.logprobs[previous_num_tokens[i]:]
else:
top_logprobs = None
offsets = len(previous_texts[i])
if request.echo and not has_echoed[i]:
if not echo_without_generation:
delta_text = res.prompt + delta_text
token_ids = res.prompt_token_ids + token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs + top_logprobs
else: # only just return the prompt
delta_text = res.prompt
token_ids = res.prompt_token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
if request.logprobs is not None:
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=offsets,
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = (LogProbs()
if request.logprobs is not None else None)
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
usage=final_usage,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
return completion_stream_generator()
return completion_stream_generator(request, result_generator,
echo_without_generation,
self._create_logprobs,
request_id, created_time,
model_name)
# Non-streaming response
final_res: RequestOutput = None
@ -228,62 +272,13 @@ class OpenAIServingCompletion(OpenAIServing):
await self.engine.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None
choices = []
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.logprobs is not None:
if not echo_without_generation:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.echo:
token_ids = prompt_token_ids + token_ids
top_logprobs = prompt_logprobs + top_logprobs
else:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
if not echo_without_generation:
output_text = output.text
if request.echo:
output_text = prompt_text + output_text
else:
output_text = prompt_text
choice_data = CompletionResponseChoice(
index=output.index,
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
response = request_output_to_completion_response(
final_res, request, echo_without_generation, self._create_logprobs,
request_id, created_time, model_name)
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:

View File

@ -1,6 +1,6 @@
import asyncio
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
@ -104,27 +104,30 @@ class OpenAIServing:
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
async def _check_length(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[ErrorResponse]]:
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
return input_ids, self.create_error_response(
raise ValueError(
f"This model's maximum context length is {self.max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
else:
return input_ids, None
return input_ids

View File

@ -163,7 +163,7 @@ def prepare_hf_model_weights(
use_safetensors = True
break
logger.info(f"Downloading model weights {allow_patterns}")
logger.info(f"Using model weights format {allow_patterns}")
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):