refactor completion api for readability (#2499)

Simon Mo 2024-01-18 16:45:14 -08:00 committed by GitHub
parent d2a68364c4
commit dd7e8f5f64
6 changed files with 286 additions and 255 deletions

View File

@@ -88,6 +88,16 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11)
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
assert completion.choices[0].text is not None and len(
completion.choices[0].text) >= 5
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
messages = [{
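Note on the new test above: the completions endpoint now also accepts a prompt given as token IDs rather than text. A minimal, illustrative sketch of producing such IDs with a Hugging Face tokenizer (the model name is a placeholder, not part of this change):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("my-model")  # placeholder model name
    token_ids = tokenizer("Hello, world").input_ids        # e.g. [101, 7592, ...]
    # These IDs can be passed directly as the `prompt` field of a completions
    # request, just as the test does with the dummy prompt [0, 0, 0, 0, 0].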

View File

@@ -6,6 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
class ErrorResponse(BaseModel):
@@ -78,6 +79,26 @@ class ChatCompletionRequest(BaseModel):
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
def to_sampling_params(self) -> SamplingParams:
return SamplingParams(
n=self.n,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
min_p=self.min_p,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
use_beam_search=self.use_beam_search,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
)
class CompletionRequest(BaseModel):
model: str
@@ -107,6 +128,30 @@ class CompletionRequest(BaseModel):
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0
return SamplingParams(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens),
)
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
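The two to_sampling_params helpers above move parameter translation onto the request models themselves. A small usage sketch (the model name and prompt are placeholders; vLLM must be installed for the import to resolve):

    from vllm.entrypoints.openai.protocol import CompletionRequest

    request = CompletionRequest(model="my-model",
                                prompt="Hello, world",
                                max_tokens=16,
                                temperature=0.7,
                                echo=True)
    # The request object now owns the mapping onto engine-level SamplingParams.
    # (echo with max_tokens == 0 would clamp max_tokens to 1, per the
    # echo_without_generation flag above.)
    sampling_params = request.to_sampling_params()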

View File

@@ -11,7 +11,6 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.entrypoints.openai.serving_engine import OpenAIServing
logger = init_logger(__name__)
@@ -60,32 +59,11 @@ class OpenAIServingChat(OpenAIServing):
f"Error in applying chat template from request: {str(e)}")
return self.create_error_response(str(e))
token_ids, error_check_ret = await self._check_length(request,
prompt=prompt)
if error_check_ret is not None:
return error_check_ret
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
try: try:
spaces_between_special_tokens = request.spaces_between_special_tokens token_ids = self._validate_prompt_and_tokenize(request,
sampling_params = SamplingParams( prompt=prompt)
n=request.n, sampling_params = request.to_sampling_params()
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
top_p=request.top_p,
min_p=request.min_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
except ValueError as e:
return self.create_error_response(str(e))

View File

@@ -1,20 +1,194 @@
import time
from fastapi import Request
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, AsyncIterator
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from .protocol import (CompletionRequest, CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse, LogProbs, UsageInfo)
from .protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs,
UsageInfo,
)
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.entrypoints.openai.serving_engine import OpenAIServing
logger = init_logger(__name__)
async def completion_stream_generator(
request: CompletionRequest,
result_generator: AsyncIterator[RequestOutput],
echo_without_generation, create_logprobs_fn, request_id, created_time,
model_name) -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
has_echoed = [False] * request.n
async for res in result_generator:
# TODO: handle client disconnect for streaming
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
token_ids = output.token_ids[previous_num_tokens[i]:]
if request.logprobs is not None:
top_logprobs = output.logprobs[previous_num_tokens[i]:]
else:
top_logprobs = None
offsets = len(previous_texts[i])
if request.echo and not has_echoed[i]:
if not echo_without_generation:
delta_text = res.prompt + delta_text
token_ids = res.prompt_token_ids + token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs + top_logprobs
else: # only just return the prompt
delta_text = res.prompt
token_ids = res.prompt_token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
if request.logprobs is not None:
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=offsets,
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
response_json = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
)
]).json(exclude_unset=True, ensure_ascii=False)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
)
],
usage=final_usage,
).json(exclude_unset=True, ensure_ascii=False)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
def parse_prompt_format(prompt) -> tuple[bool, list]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens = False
prompts = [prompt] # case 1: a string
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
elif isinstance(prompt[0], str):
prompt_is_tokens = False
prompts = prompt # case 2: array of strings
elif isinstance(prompt[0], int):
prompt_is_tokens = True
prompts = [prompt] # case 3: array of tokens
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError(
"prompt must be a string, array of strings, array of tokens, or array of token arrays"
)
return prompt_is_tokens, prompts
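To make the four accepted prompt shapes concrete, here is what parse_prompt_format returns for each (values are arbitrary; the module path is assumed to be vllm.entrypoints.openai.serving_completion):

    from vllm.entrypoints.openai.serving_completion import parse_prompt_format

    parse_prompt_format("a string")            # (False, ["a string"])
    parse_prompt_format(["one", "two"])        # (False, ["one", "two"])
    parse_prompt_format([1, 2, 3])             # (True, [[1, 2, 3]])
    parse_prompt_format([[1, 2, 3], [4, 5]])   # (True, [[1, 2, 3], [4, 5]])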
def request_output_to_completion_response(final_res: RequestOutput, request,
echo_without_generation,
create_logprobs_fn, request_id,
created_time,
model_name) -> CompletionResponse:
assert final_res is not None
choices = []
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.logprobs is not None:
if not echo_without_generation:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.echo:
token_ids = prompt_token_ids + token_ids
top_logprobs = prompt_logprobs + top_logprobs
else:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
if not echo_without_generation:
output_text = output.text
if request.echo:
output_text = prompt_text + output_text
else:
output_text = prompt_text
choice_data = CompletionResponseChoice(
index=output.index,
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
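For reference, the non-streaming payload assembled by this helper has the usual OpenAI completion shape; a hand-built illustration (field values are invented, except the usage numbers, which echo the test at the top of this diff):

    from vllm.entrypoints.openai.protocol import (CompletionResponse,
                                                  CompletionResponseChoice,
                                                  UsageInfo)

    example = CompletionResponse(
        id="cmpl-example",
        created=1705622400,  # arbitrary unix timestamp
        model="my-model",    # placeholder model name
        choices=[
            CompletionResponseChoice(index=0,
                                     text=" example output",
                                     finish_reason="length")
        ],
        usage=UsageInfo(prompt_tokens=6, completion_tokens=5, total_tokens=11),
    )
    print(example.json(exclude_unset=True))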
class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, served_model: str):
@@ -32,7 +206,6 @@ class OpenAIServingCompletion(OpenAIServing):
suffix)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
@@ -40,83 +213,42 @@ class OpenAIServingCompletion(OpenAIServing):
# OpenAI API supports echoing the prompt when max_tokens is 0.
echo_without_generation = request.echo and request.max_tokens == 0
# Return error for unsupported features.
if request.suffix is not None:
# The language models we currently support do not support suffix.
return self.create_error_response(
"suffix is not currently supported")
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return self.create_error_response(
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return self.create_error_response(
"please provide at least one prompt")
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
prompt = request.prompt
elif isinstance(first_element, (str, list)):
# TODO: handles multiple prompt case in list[list[int]]
if len(request.prompt) > 1:
return self.create_error_response(
"multiple prompts in a batch is not currently supported"
)
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
prompt = request.prompt
if use_token_ids:
_, error_check_ret = await self._check_length(request,
prompt_ids=prompt)
else:
token_ids, error_check_ret = await self._check_length(
request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
created_time = int(time.monotonic())
try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
min_p=request.min_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens
if not echo_without_generation else 1,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
prompt_logprobs=request.logprobs if request.echo else None,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
except ValueError as e:
return self.create_error_response(str(e))
if use_token_ids:
result_generator = self.engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=prompt)
else:
result_generator = self.engine.generate(prompt, sampling_params,
request_id, token_ids)
# Schedule the request and get the result generator.
try:
sampling_params = request.to_sampling_params()
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
if len(prompts) > 1:
raise ValueError(
"Batching in completion API is not supported.")
prompt = prompts[0]
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
input_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
result_generator = self.engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=input_ids)
except ValueError as e:
return self.create_error_response(str(e))
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
@@ -124,101 +256,13 @@ class OpenAIServingCompletion(OpenAIServing):
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
def create_stream_response_json(
index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
usage: Optional[UsageInfo] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
if usage is not None:
response.usage = usage
response_json = response.json(exclude_unset=True,
ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
has_echoed = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
token_ids = output.token_ids[previous_num_tokens[i]:]
if request.logprobs is not None:
top_logprobs = output.logprobs[previous_num_tokens[i]:]
else:
top_logprobs = None
offsets = len(previous_texts[i])
if request.echo and not has_echoed[i]:
if not echo_without_generation:
delta_text = res.prompt + delta_text
token_ids = res.prompt_token_ids + token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs + top_logprobs
else: # only just return the prompt
delta_text = res.prompt
token_ids = res.prompt_token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
if request.logprobs is not None:
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=offsets,
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = (LogProbs()
if request.logprobs is not None else None)
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
usage=final_usage,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
return completion_stream_generator()
return completion_stream_generator(request, result_generator,
echo_without_generation,
self._create_logprobs,
request_id, created_time,
model_name)
# Non-streaming response
final_res: RequestOutput = None
@@ -228,62 +272,13 @@ class OpenAIServingCompletion(OpenAIServing):
await self.engine.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None
choices = []
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.logprobs is not None:
if not echo_without_generation:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.echo:
token_ids = prompt_token_ids + token_ids
top_logprobs = prompt_logprobs + top_logprobs
else:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
if not echo_without_generation:
output_text = output.text
if request.echo:
output_text = prompt_text + output_text
else:
output_text = prompt_text
choice_data = CompletionResponseChoice(
index=output.index,
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response = request_output_to_completion_response(
final_res, request, echo_without_generation, self._create_logprobs,
request_id, created_time, model_name)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:

View File

@ -1,6 +1,6 @@
import asyncio import asyncio
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Union
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -104,27 +104,30 @@ class OpenAIServing:
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
async def _check_length(
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[ErrorResponse]]:
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
prompt_ids: Optional[List[int]] = None) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
return input_ids, self.create_error_response(
raise ValueError(
f"This model's maximum context length is {self.max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
else:
return input_ids, None
return input_ids
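The behavioural change here is the error contract: _validate_prompt_and_tokenize raises ValueError where _check_length returned an (input_ids, error_response) tuple. A sketch of the calling pattern (a fragment, assuming it runs inside an OpenAIServing subclass, mirroring the serving_chat and serving_completion changes above):

    try:
        input_ids = self._validate_prompt_and_tokenize(request, prompt=prompt)
        sampling_params = request.to_sampling_params()
    except ValueError as e:
        # Tokenization and length problems now surface as exceptions and are
        # turned into an OpenAI-style error response at the call site.
        return self.create_error_response(str(e))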

View File

@@ -163,7 +163,7 @@ def prepare_hf_model_weights(
use_safetensors = True
break
logger.info(f"Downloading model weights {allow_patterns}")
logger.info(f"Using model weights format {allow_patterns}")
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
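The comment above describes a file-lock guard around weight downloads. A stand-alone sketch of the same pattern using the filelock package (the lock path and the fetch_weights function are illustrative placeholders, not vLLM's actual get_lock helper):

    import os
    from filelock import FileLock

    def download_with_lock(model_name: str, cache_dir: str) -> None:
        # One lock file per model: concurrent processes serialize here instead
        # of downloading the same weights twice.
        lock_path = os.path.join(cache_dir, model_name.replace("/", "-") + ".lock")
        with FileLock(lock_path):
            fetch_weights(model_name, cache_dir)  # placeholder for the real download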