mirror of https://github.com/vllm-project/vllm
refactor completion api for readability (#2499)
This commit is contained in:
parent d2a68364c4
commit dd7e8f5f64
@@ -88,6 +88,16 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5, prompt_tokens=6, total_tokens=11)
 
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert completion.choices[0].text is not None and len(
+        completion.choices[0].text) >= 5
+
 
 async def test_single_chat_session(server, client: openai.AsyncOpenAI):
     messages = [{
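The new test drives the token-ID prompt path through the suite's `server` and `client` fixtures. For reference, a minimal standalone sketch of the same call against a locally running OpenAI-compatible vLLM server; the base URL, API key, and model name below are assumptions, not part of this commit:

import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    # The prompt is given as token IDs instead of text.
    completion = await client.completions.create(
        model="facebook/opt-125m",  # placeholder served model
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    print(completion.choices[0].text)


asyncio.run(main())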
@@ -6,6 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, Field
 
 from vllm.utils import random_uuid
+from vllm.sampling_params import SamplingParams
 
 
 class ErrorResponse(BaseModel):
@@ -78,6 +79,26 @@ class ChatCompletionRequest(BaseModel):
     repetition_penalty: Optional[float] = 1.0
     min_p: Optional[float] = 0.0
 
+    def to_sampling_params(self) -> SamplingParams:
+        return SamplingParams(
+            n=self.n,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            min_p=self.min_p,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            max_tokens=self.max_tokens,
+            best_of=self.best_of,
+            top_k=self.top_k,
+            ignore_eos=self.ignore_eos,
+            use_beam_search=self.use_beam_search,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
+        )
+
 
 class CompletionRequest(BaseModel):
     model: str
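To illustrate the new helper: a minimal sketch of turning a chat request into engine sampling parameters. The model name and message content are placeholders; `ChatCompletionRequest` is the pydantic model shown above and `SamplingParams` comes from `vllm.sampling_params`:

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="facebook/opt-125m",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=0.2,
    max_tokens=64,
)
# All OpenAI-style sampling knobs are now mapped to SamplingParams in one place.
sampling_params = request.to_sampling_params()
print(sampling_params.temperature, sampling_params.max_tokens)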
@@ -107,6 +128,30 @@ class CompletionRequest(BaseModel):
     repetition_penalty: Optional[float] = 1.0
     min_p: Optional[float] = 0.0
 
+    def to_sampling_params(self):
+        echo_without_generation = self.echo and self.max_tokens == 0
+
+        return SamplingParams(
+            n=self.n,
+            best_of=self.best_of,
+            presence_penalty=self.presence_penalty,
+            frequency_penalty=self.frequency_penalty,
+            repetition_penalty=self.repetition_penalty,
+            temperature=self.temperature,
+            top_p=self.top_p,
+            top_k=self.top_k,
+            min_p=self.min_p,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
+            ignore_eos=self.ignore_eos,
+            max_tokens=self.max_tokens if not echo_without_generation else 1,
+            logprobs=self.logprobs,
+            use_beam_search=self.use_beam_search,
+            prompt_logprobs=self.logprobs if self.echo else None,
+            skip_special_tokens=self.skip_special_tokens,
+            spaces_between_special_tokens=(self.spaces_between_special_tokens),
+        )
+
 
 class LogProbs(BaseModel):
     text_offset: List[int] = Field(default_factory=list)
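Likewise for the completion request. The sketch below highlights the echo-only case handled above: `echo=True` with `max_tokens == 0` is mapped to a single engine step with prompt logprobs so that only the prompt is returned. The model name is a placeholder:

from vllm.entrypoints.openai.protocol import CompletionRequest

request = CompletionRequest(
    model="facebook/opt-125m",  # placeholder model name
    prompt="The quick brown fox",
    echo=True,
    max_tokens=0,
    logprobs=1,
)
params = request.to_sampling_params()
print(params.max_tokens)       # 1, because echo_without_generation is true
print(params.prompt_logprobs)  # 1, mirrors `logprobs` because echo is set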
@@ -11,7 +11,6 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
 from vllm.outputs import RequestOutput
-from vllm.sampling_params import SamplingParams
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 
 logger = init_logger(__name__)
@@ -60,32 +59,11 @@ class OpenAIServingChat(OpenAIServing):
                 f"Error in applying chat template from request: {str(e)}")
             return self.create_error_response(str(e))
 
-        token_ids, error_check_ret = await self._check_length(request,
-                                                              prompt=prompt)
-        if error_check_ret is not None:
-            return error_check_ret
-
         request_id = f"cmpl-{random_uuid()}"
         try:
-            spaces_between_special_tokens = request.spaces_between_special_tokens
-            sampling_params = SamplingParams(
-                n=request.n,
-                presence_penalty=request.presence_penalty,
-                frequency_penalty=request.frequency_penalty,
-                repetition_penalty=request.repetition_penalty,
-                temperature=request.temperature,
-                top_p=request.top_p,
-                min_p=request.min_p,
-                stop=request.stop,
-                stop_token_ids=request.stop_token_ids,
-                max_tokens=request.max_tokens,
-                best_of=request.best_of,
-                top_k=request.top_k,
-                ignore_eos=request.ignore_eos,
-                use_beam_search=request.use_beam_search,
-                skip_special_tokens=request.skip_special_tokens,
-                spaces_between_special_tokens=spaces_between_special_tokens,
-            )
+            token_ids = self._validate_prompt_and_tokenize(request,
+                                                           prompt=prompt)
+            sampling_params = request.to_sampling_params()
         except ValueError as e:
             return self.create_error_response(str(e))
 
@@ -1,20 +1,194 @@
 import time
 from fastapi import Request
-from typing import AsyncGenerator, Optional
+from typing import AsyncGenerator, AsyncIterator
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from .protocol import (CompletionRequest, CompletionResponse,
-                       CompletionResponseChoice,
-                       CompletionResponseStreamChoice,
-                       CompletionStreamResponse, LogProbs, UsageInfo)
+from .protocol import (
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    LogProbs,
+    UsageInfo,
+)
 from vllm.outputs import RequestOutput
-from vllm.sampling_params import SamplingParams
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 
 logger = init_logger(__name__)
 
 
+async def completion_stream_generator(
+        request: CompletionRequest,
+        result_generator: AsyncIterator[RequestOutput],
+        echo_without_generation, create_logprobs_fn, request_id, created_time,
+        model_name) -> AsyncGenerator[str, None]:
+    previous_texts = [""] * request.n
+    previous_num_tokens = [0] * request.n
+    has_echoed = [False] * request.n
+
+    async for res in result_generator:
+        # TODO: handle client disconnect for streaming
+        for output in res.outputs:
+            i = output.index
+            delta_text = output.text[len(previous_texts[i]):]
+            token_ids = output.token_ids[previous_num_tokens[i]:]
+            if request.logprobs is not None:
+                top_logprobs = output.logprobs[previous_num_tokens[i]:]
+            else:
+                top_logprobs = None
+            offsets = len(previous_texts[i])
+            if request.echo and not has_echoed[i]:
+                if not echo_without_generation:
+                    delta_text = res.prompt + delta_text
+                    token_ids = res.prompt_token_ids + token_ids
+                    if top_logprobs:
+                        top_logprobs = res.prompt_logprobs + top_logprobs
+                else:  # only just return the prompt
+                    delta_text = res.prompt
+                    token_ids = res.prompt_token_ids
+                    if top_logprobs:
+                        top_logprobs = res.prompt_logprobs
+                has_echoed[i] = True
+            if request.logprobs is not None:
+                logprobs = create_logprobs_fn(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                    initial_text_offset=offsets,
+                )
+            else:
+                logprobs = None
+            previous_texts[i] = output.text
+            previous_num_tokens[i] = len(output.token_ids)
+            finish_reason = output.finish_reason
+            response_json = CompletionStreamResponse(
+                id=request_id,
+                created=created_time,
+                model=model_name,
+                choices=[
+                    CompletionResponseStreamChoice(
+                        index=i,
+                        text=delta_text,
+                        logprobs=logprobs,
+                        finish_reason=finish_reason,
+                    )
+                ]).json(exclude_unset=True, ensure_ascii=False)
+            yield f"data: {response_json}\n\n"
+
+            if output.finish_reason is not None:
+                logprobs = LogProbs() if request.logprobs is not None else None
+                prompt_tokens = len(res.prompt_token_ids)
+                completion_tokens = len(output.token_ids)
+                final_usage = UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                )
+                response_json = CompletionStreamResponse(
+                    id=request_id,
+                    created=created_time,
+                    model=model_name,
+                    choices=[
+                        CompletionResponseStreamChoice(
+                            index=i,
+                            text="",
+                            logprobs=logprobs,
+                            finish_reason=output.finish_reason,
+                        )
+                    ],
+                    usage=final_usage,
+                ).json(exclude_unset=True, ensure_ascii=False)
+                yield f"data: {response_json}\n\n"
+
+    yield "data: [DONE]\n\n"
+
+
+def parse_prompt_format(prompt) -> tuple[bool, list]:
+    # get the prompt, openai supports the following
+    # "a string, array of strings, array of tokens, or array of token arrays."
+    prompt_is_tokens = False
+    prompts = [prompt]  # case 1: a string
+    if isinstance(prompt, list):
+        if len(prompt) == 0:
+            raise ValueError("please provide at least one prompt")
+        elif isinstance(prompt[0], str):
+            prompt_is_tokens = False
+            prompts = prompt  # case 2: array of strings
+        elif isinstance(prompt[0], int):
+            prompt_is_tokens = True
+            prompts = [prompt]  # case 3: array of tokens
+        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
+            prompt_is_tokens = True
+            prompts = prompt  # case 4: array of token arrays
+        else:
+            raise ValueError(
+                "prompt must be a string, array of strings, array of tokens, or array of token arrays"
+            )
+    return prompt_is_tokens, prompts
+
+
+def request_output_to_completion_response(final_res: RequestOutput, request,
+                                          echo_without_generation,
+                                          create_logprobs_fn, request_id,
+                                          created_time,
+                                          model_name) -> CompletionResponse:
+    assert final_res is not None
+    choices = []
+    prompt_token_ids = final_res.prompt_token_ids
+    prompt_logprobs = final_res.prompt_logprobs
+    prompt_text = final_res.prompt
+    for output in final_res.outputs:
+        if request.logprobs is not None:
+            if not echo_without_generation:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                if request.echo:
+                    token_ids = prompt_token_ids + token_ids
+                    top_logprobs = prompt_logprobs + top_logprobs
+            else:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+            logprobs = create_logprobs_fn(
+                token_ids=token_ids,
+                top_logprobs=top_logprobs,
+                num_output_top_logprobs=request.logprobs,
+            )
+        else:
+            logprobs = None
+        if not echo_without_generation:
+            output_text = output.text
+            if request.echo:
+                output_text = prompt_text + output_text
+        else:
+            output_text = prompt_text
+        choice_data = CompletionResponseChoice(
+            index=output.index,
+            text=output_text,
+            logprobs=logprobs,
+            finish_reason=output.finish_reason,
+        )
+        choices.append(choice_data)
+
+    num_prompt_tokens = len(final_res.prompt_token_ids)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens,
+    )
+
+    return CompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+
 class OpenAIServingCompletion(OpenAIServing):
 
     def __init__(self, engine: AsyncLLMEngine, served_model: str):
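A small sketch of what `parse_prompt_format` returns for the four prompt shapes the OpenAI API accepts. The import path is an assumption based on the imports in this hunk (the completion-serving module under `vllm.entrypoints.openai`); the return values follow directly from the function above:

from vllm.entrypoints.openai.serving_completion import parse_prompt_format

print(parse_prompt_format("a single string"))
# -> (False, ['a single string'])
print(parse_prompt_format(["first prompt", "second prompt"]))
# -> (False, ['first prompt', 'second prompt'])
print(parse_prompt_format([1, 2, 3]))
# -> (True, [[1, 2, 3]])
print(parse_prompt_format([[1, 2], [3, 4]]))
# -> (True, [[1, 2], [3, 4]])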
@@ -32,7 +206,6 @@ class OpenAIServingCompletion(OpenAIServing):
               suffix)
             - logit_bias (to be supported by vLLM engine)
         """
-
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
@@ -40,83 +213,42 @@ class OpenAIServingCompletion(OpenAIServing):
         # OpenAI API supports echoing the prompt when max_tokens is 0.
         echo_without_generation = request.echo and request.max_tokens == 0
 
+        # Return error for unsupported features.
         if request.suffix is not None:
-            # The language models we currently support do not support suffix.
             return self.create_error_response(
                 "suffix is not currently supported")
 
         if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
             return self.create_error_response(
                 "logit_bias is not currently supported")
 
         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"
-
-        use_token_ids = False
-        if isinstance(request.prompt, list):
-            if len(request.prompt) == 0:
-                return self.create_error_response(
-                    "please provide at least one prompt")
-            first_element = request.prompt[0]
-            if isinstance(first_element, int):
-                use_token_ids = True
-                prompt = request.prompt
-            elif isinstance(first_element, (str, list)):
-                # TODO: handles multiple prompt case in list[list[int]]
-                if len(request.prompt) > 1:
-                    return self.create_error_response(
-                        "multiple prompts in a batch is not currently supported"
-                    )
-                use_token_ids = not isinstance(first_element, str)
-                prompt = request.prompt[0]
-        else:
-            prompt = request.prompt
-
-        if use_token_ids:
-            _, error_check_ret = await self._check_length(request,
-                                                          prompt_ids=prompt)
-        else:
-            token_ids, error_check_ret = await self._check_length(
-                request, prompt=prompt)
-        if error_check_ret is not None:
-            return error_check_ret
-
         created_time = int(time.monotonic())
-        try:
-            spaces_between_special_tokens = request.spaces_between_special_tokens
-            sampling_params = SamplingParams(
-                n=request.n,
-                best_of=request.best_of,
-                presence_penalty=request.presence_penalty,
-                frequency_penalty=request.frequency_penalty,
-                repetition_penalty=request.repetition_penalty,
-                temperature=request.temperature,
-                top_p=request.top_p,
-                top_k=request.top_k,
-                min_p=request.min_p,
-                stop=request.stop,
-                stop_token_ids=request.stop_token_ids,
-                ignore_eos=request.ignore_eos,
-                max_tokens=request.max_tokens
-                if not echo_without_generation else 1,
-                logprobs=request.logprobs,
-                use_beam_search=request.use_beam_search,
-                prompt_logprobs=request.logprobs if request.echo else None,
-                skip_special_tokens=request.skip_special_tokens,
-                spaces_between_special_tokens=spaces_between_special_tokens,
-            )
-        except ValueError as e:
-            return self.create_error_response(str(e))
-
-        if use_token_ids:
+
+        # Schedule the request and get the result generator.
+        try:
+            sampling_params = request.to_sampling_params()
+
+            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
+
+            if len(prompts) > 1:
+                raise ValueError(
+                    "Batching in completion API is not supported.")
+            prompt = prompts[0]
+
+            if prompt_is_tokens:
+                input_ids = self._validate_prompt_and_tokenize(
+                    request, prompt_ids=prompt)
+            else:
+                input_ids = self._validate_prompt_and_tokenize(request,
+                                                               prompt=prompt)
+
             result_generator = self.engine.generate(None,
                                                     sampling_params,
                                                     request_id,
-                                                    prompt_token_ids=prompt)
-        else:
-            result_generator = self.engine.generate(prompt, sampling_params,
-                                                    request_id, token_ids)
+                                                    prompt_token_ids=input_ids)
+        except ValueError as e:
+            return self.create_error_response(str(e))
 
         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use beam search.
@@ -124,101 +256,13 @@ class OpenAIServingCompletion(OpenAIServing):
                   and (request.best_of is None or request.n == request.best_of)
                   and not request.use_beam_search)
 
-        def create_stream_response_json(
-            index: int,
-            text: str,
-            logprobs: Optional[LogProbs] = None,
-            finish_reason: Optional[str] = None,
-            usage: Optional[UsageInfo] = None,
-        ) -> str:
-            choice_data = CompletionResponseStreamChoice(
-                index=index,
-                text=text,
-                logprobs=logprobs,
-                finish_reason=finish_reason,
-            )
-            response = CompletionStreamResponse(
-                id=request_id,
-                created=created_time,
-                model=model_name,
-                choices=[choice_data],
-            )
-            if usage is not None:
-                response.usage = usage
-            response_json = response.json(exclude_unset=True,
-                                          ensure_ascii=False)
-
-            return response_json
-
-        async def completion_stream_generator() -> AsyncGenerator[str, None]:
-            previous_texts = [""] * request.n
-            previous_num_tokens = [0] * request.n
-            has_echoed = [False] * request.n
-            async for res in result_generator:
-                res: RequestOutput
-                for output in res.outputs:
-                    i = output.index
-                    delta_text = output.text[len(previous_texts[i]):]
-                    token_ids = output.token_ids[previous_num_tokens[i]:]
-                    if request.logprobs is not None:
-                        top_logprobs = output.logprobs[previous_num_tokens[i]:]
-                    else:
-                        top_logprobs = None
-                    offsets = len(previous_texts[i])
-                    if request.echo and not has_echoed[i]:
-                        if not echo_without_generation:
-                            delta_text = res.prompt + delta_text
-                            token_ids = res.prompt_token_ids + token_ids
-                            if top_logprobs:
-                                top_logprobs = res.prompt_logprobs + top_logprobs
-                        else:  # only just return the prompt
-                            delta_text = res.prompt
-                            token_ids = res.prompt_token_ids
-                            if top_logprobs:
-                                top_logprobs = res.prompt_logprobs
-                        has_echoed[i] = True
-                    if request.logprobs is not None:
-                        logprobs = self._create_logprobs(
-                            token_ids=token_ids,
-                            top_logprobs=top_logprobs,
-                            num_output_top_logprobs=request.logprobs,
-                            initial_text_offset=offsets,
-                        )
-                    else:
-                        logprobs = None
-                    previous_texts[i] = output.text
-                    previous_num_tokens[i] = len(output.token_ids)
-                    finish_reason = output.finish_reason
-                    response_json = create_stream_response_json(
-                        index=i,
-                        text=delta_text,
-                        logprobs=logprobs,
-                        finish_reason=finish_reason,
-                    )
-                    yield f"data: {response_json}\n\n"
-                    if output.finish_reason is not None:
-                        logprobs = (LogProbs()
-                                    if request.logprobs is not None else None)
-                        prompt_tokens = len(res.prompt_token_ids)
-                        completion_tokens = len(output.token_ids)
-                        final_usage = UsageInfo(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        )
-                        response_json = create_stream_response_json(
-                            index=i,
-                            text="",
-                            logprobs=logprobs,
-                            finish_reason=output.finish_reason,
-                            usage=final_usage,
-                        )
-                        yield f"data: {response_json}\n\n"
-            yield "data: [DONE]\n\n"
-
         # Streaming response
         if stream:
-            return completion_stream_generator()
+            return completion_stream_generator(request, result_generator,
+                                               echo_without_generation,
+                                               self._create_logprobs,
+                                               request_id, created_time,
+                                               model_name)
 
         # Non-streaming response
         final_res: RequestOutput = None
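The streaming branch now delegates to the module-level `completion_stream_generator`, which emits `data: ...` server-sent events and terminates with `data: [DONE]`. As a hedged sketch of the other end of that contract, this is how a client could consume the stream through the OpenAI SDK; the base URL, API key, and model name are assumptions:

import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.completions.create(
        model="facebook/opt-125m",  # placeholder served model
        prompt="The capital of France is",
        max_tokens=16,
        stream=True,
    )
    # Each chunk corresponds to one "data: ..." event yielded by the generator.
    async for chunk in stream:
        print(chunk.choices[0].text, end="", flush=True)
    print()


asyncio.run(main())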
@@ -228,62 +272,13 @@ class OpenAIServingCompletion(OpenAIServing):
                 await self.engine.abort(request_id)
                 return self.create_error_response("Client disconnected")
             final_res = res
-        assert final_res is not None
-        choices = []
-        prompt_token_ids = final_res.prompt_token_ids
-        prompt_logprobs = final_res.prompt_logprobs
-        prompt_text = final_res.prompt
-        for output in final_res.outputs:
-            if request.logprobs is not None:
-                if not echo_without_generation:
-                    token_ids = output.token_ids
-                    top_logprobs = output.logprobs
-                    if request.echo:
-                        token_ids = prompt_token_ids + token_ids
-                        top_logprobs = prompt_logprobs + top_logprobs
-                else:
-                    token_ids = prompt_token_ids
-                    top_logprobs = prompt_logprobs
-                logprobs = self._create_logprobs(
-                    token_ids=token_ids,
-                    top_logprobs=top_logprobs,
-                    num_output_top_logprobs=request.logprobs,
-                )
-            else:
-                logprobs = None
-            if not echo_without_generation:
-                output_text = output.text
-                if request.echo:
-                    output_text = prompt_text + output_text
-            else:
-                output_text = prompt_text
-            choice_data = CompletionResponseChoice(
-                index=output.index,
-                text=output_text,
-                logprobs=logprobs,
-                finish_reason=output.finish_reason,
-            )
-            choices.append(choice_data)
-
-        num_prompt_tokens = len(final_res.prompt_token_ids)
-        num_generated_tokens = sum(
-            len(output.token_ids) for output in final_res.outputs)
-        usage = UsageInfo(
-            prompt_tokens=num_prompt_tokens,
-            completion_tokens=num_generated_tokens,
-            total_tokens=num_prompt_tokens + num_generated_tokens,
-        )
-        response = CompletionResponse(
-            id=request_id,
-            created=created_time,
-            model=model_name,
-            choices=choices,
-            usage=usage,
-        )
-
+        response = request_output_to_completion_response(
+            final_res, request, echo_without_generation, self._create_logprobs,
+            request_id, created_time, model_name)
+
+        # When user requests streaming but we don't stream, we still need to
+        # return a streaming response with a single event.
         if request.stream:
-            # When user requests streaming but we don't stream, we still need to
-            # return a streaming response with a single event.
             response_json = response.json(ensure_ascii=False)
 
         async def fake_stream_generator() -> AsyncGenerator[str, None]:
@@ -1,6 +1,6 @@
 import asyncio
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -104,27 +104,30 @@ class OpenAIServing:
                 err_type="NotFoundError",
                 status_code=HTTPStatus.NOT_FOUND)
 
-    async def _check_length(
+    def _validate_prompt_and_tokenize(
         self,
         request: Union[ChatCompletionRequest, CompletionRequest],
         prompt: Optional[str] = None,
-        prompt_ids: Optional[List[int]] = None
-    ) -> Tuple[List[int], Optional[ErrorResponse]]:
-        assert (not (prompt is None and prompt_ids is None)
-                and not (prompt is not None and prompt_ids is not None)
-                ), "Either prompt or prompt_ids should be provided."
+        prompt_ids: Optional[List[int]] = None) -> List[int]:
+        if not (prompt or prompt_ids):
+            raise ValueError("Either prompt or prompt_ids should be provided.")
+        if (prompt and prompt_ids):
+            raise ValueError(
+                "Only one of prompt or prompt_ids should be provided.")
+
         input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
             prompt).input_ids
         token_num = len(input_ids)
 
         if request.max_tokens is None:
             request.max_tokens = self.max_model_len - token_num
 
         if token_num + request.max_tokens > self.max_model_len:
-            return input_ids, self.create_error_response(
+            raise ValueError(
                 f"This model's maximum context length is {self.max_model_len} tokens. "
                 f"However, you requested {request.max_tokens + token_num} tokens "
                 f"({token_num} in the messages, "
                 f"{request.max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.", )
         else:
-            return input_ids, None
+            return input_ids
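The refactor changes the helper's contract: instead of returning an `(input_ids, error_response)` pair, it raises `ValueError`, and the serving layer converts the exception into an error response (as seen in the chat and completion hunks above). A hedged sketch of that calling pattern; the `serving` object and surrounding handler are hypothetical:

async def handle_request(serving, request, prompt):
    """Hypothetical handler showing the new error-handling convention."""
    try:
        # Raises ValueError if both or neither of prompt / prompt_ids are
        # given, or if prompt length plus max_tokens exceeds the model's
        # context length.
        input_ids = serving._validate_prompt_and_tokenize(request,
                                                          prompt=prompt)
        sampling_params = request.to_sampling_params()
    except ValueError as e:
        return serving.create_error_response(str(e))
    return input_ids, sampling_params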
@@ -163,7 +163,7 @@ def prepare_hf_model_weights(
             use_safetensors = True
             break
 
-    logger.info(f"Downloading model weights {allow_patterns}")
+    logger.info(f"Using model weights format {allow_patterns}")
     # Use file lock to prevent multiple processes from
     # downloading the same model weights at the same time.
     with get_lock(model_name_or_path, cache_dir):