mirror of https://github.com/vllm-project/vllm
Support chat template and `echo` for chat API (#1756)
This commit is contained in:
parent
05a38612b0
commit
66785cc05c
|
@ -107,6 +107,7 @@ OpenAI-Compatible Server
|
|||
------------------------
|
||||
|
||||
vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
|
||||
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
|
||||
|
||||
Start the server:
|
||||
|
||||
|
@ -122,7 +123,13 @@ Use model from www.modelscope.cn
|
|||
$ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \
|
||||
$ --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code
|
||||
|
||||
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_ and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
|
||||
By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python -m vllm.entrypoints.openai.api_server \
|
||||
$ --model facebook/opt-125m \
|
||||
$ --chat-template ./examples/template_chatml.json
|
||||
|
||||
This server can be queried in the same format as OpenAI API. For example, list the models:
|
||||
|
||||
|
@ -130,6 +137,9 @@ This server can be queried in the same format as OpenAI API. For example, list t
|
|||
|
||||
$ curl http://localhost:8000/v1/models
|
||||
|
||||
Using OpenAI Completions API with vLLM
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Query the model with input prompts:
|
||||
|
||||
.. code-block:: console
|
||||
|
@ -156,3 +166,45 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
|
|||
print("Completion result:", completion)
|
||||
|
||||
For a more detailed client example, refer to `examples/openai_completion_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_completion_client.py>`_.
|
||||
|
||||
Using OpenAI Chat API with vLLM
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
The vLLM server is designed to support the OpenAI Chat API, allowing you to engage in dynamic conversations with the model. The chat interface is a more interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
|
||||
|
||||
Querying the model using OpenAI Chat API:
|
||||
|
||||
You can use the `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_ endpoint to communicate with the model in a chat-like interface:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ curl http://localhost:8000/v1/chat/completions \
|
||||
$ -H "Content-Type: application/json" \
|
||||
$ -d '{
|
||||
$ "model": "facebook/opt-125m",
|
||||
$ "messages": [
|
||||
$ {"role": "system", "content": "You are a helpful assistant."},
|
||||
$ {"role": "user", "content": "Who won the world series in 2020?"}
|
||||
$ ]
|
||||
$ }'
|
||||
|
||||
Python Client Example:
|
||||
|
||||
Using the `openai` python package, you can also communicate with the model in a chat-like manner:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import openai
|
||||
# Set OpenAI's API key and API base to use vLLM's API server.
|
||||
openai.api_key = "EMPTY"
|
||||
openai.api_base = "http://localhost:8000/v1"
|
||||
chat_response = openai.ChatCompletion.create(
|
||||
model="facebook/opt-125m",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Tell me a joke."},
|
||||
]
|
||||
)
|
||||
print("Chat response:", chat_response)
|
||||
|
||||
For more in-depth examples and advanced features of the chat API, you can refer to the official OpenAI documentation.
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
|
||||
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
### Instruction:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
### Response:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'user_context' %}
|
||||
### Input:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
|
||||
### Response:
|
||||
{% endif %}
|
|
@ -0,0 +1,2 @@
|
|||
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
|
|
@ -0,0 +1,30 @@
|
|||
<#meta#>
|
||||
- Date: {{ (messages|selectattr('role', 'equalto', 'meta-current_date')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-current_date')|list) else '' }}
|
||||
- Task: {{ (messages|selectattr('role', 'equalto', 'meta-task_name')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-task_name')|list) else '' }}
|
||||
<#system#>
|
||||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
|
||||
<#chat#>
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
<#user#>
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
<#bot#>
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'user_context' %}
|
||||
<#user_context#>
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
|
||||
<#bot#>
|
||||
{% endif %}
|
|
@ -0,0 +1,119 @@
|
|||
from argparse import Namespace
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from vllm.entrypoints.openai.api_server import *
|
||||
|
||||
# Define models, templates, and their corresponding expected outputs
|
||||
MODEL_TEMPLATE_GENERATON_OUTPUT = [
|
||||
("facebook/opt-125m", None, True,
|
||||
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
||||
("facebook/opt-125m", None, False,
|
||||
"Hello</s>Hi there!</s>What is the capital of</s>"),
|
||||
("facebook/opt-125m", "../../examples/template_chatml.jinja", True,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of<|im_end|>
|
||||
<|im_start|>assistant
|
||||
"""),
|
||||
("facebook/opt-125m", "../../examples/template_chatml.jinja", False,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of""")
|
||||
]
|
||||
|
||||
TEST_MESSAGES = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': 'Hi there!'
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'What is the capital of'
|
||||
},
|
||||
]
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockTokenizer:
|
||||
chat_template = None
|
||||
|
||||
|
||||
def test_load_chat_template():
|
||||
# Testing chatml template
|
||||
template = "../../examples/template_chatml.jinja"
|
||||
mock_args = Namespace(chat_template=template)
|
||||
tokenizer = MockTokenizer()
|
||||
|
||||
# Call the function with the mocked args
|
||||
load_chat_template(mock_args, tokenizer)
|
||||
|
||||
template_content = tokenizer.chat_template
|
||||
|
||||
# Test assertions
|
||||
assert template_content is not None
|
||||
# Hard coded value for template_chatml.jinja
|
||||
assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
|
||||
|
||||
|
||||
def test_no_load_chat_template():
|
||||
# Testing chatml template
|
||||
template = "../../examples/does_not_exist"
|
||||
mock_args = Namespace(chat_template=template)
|
||||
tokenizer = MockTokenizer()
|
||||
|
||||
# Call the function with the mocked args
|
||||
load_chat_template(mock_args, tokenizer=tokenizer)
|
||||
template_content = tokenizer.chat_template
|
||||
|
||||
# Test assertions
|
||||
assert template_content is not None
|
||||
# Hard coded value for template_chatml.jinja
|
||||
assert template_content == """../../examples/does_not_exist"""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model,template,add_generation_prompt,expected_output",
|
||||
MODEL_TEMPLATE_GENERATON_OUTPUT)
|
||||
async def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
expected_output):
|
||||
# Initialize the tokenizer
|
||||
tokenizer = get_tokenizer(tokenizer_name=model)
|
||||
|
||||
mock_args = Namespace(chat_template=template)
|
||||
load_chat_template(mock_args, tokenizer)
|
||||
|
||||
# Create a mock request object using keyword arguments
|
||||
mock_request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=TEST_MESSAGES,
|
||||
add_generation_prompt=add_generation_prompt)
|
||||
|
||||
# Call the function and get the result
|
||||
result = tokenizer.apply_chat_template(
|
||||
conversation=mock_request.messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=mock_request.add_generation_prompt)
|
||||
|
||||
# Test assertion
|
||||
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
|
||||
|
||||
|
||||
def test_health_endpoint():
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import codecs
|
||||
import json
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
|
@ -14,7 +15,6 @@ from fastapi import Request
|
|||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, Response
|
||||
from packaging import version
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
@ -31,20 +31,55 @@ from vllm.sampling_params import SamplingParams
|
|||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
try:
|
||||
import fastchat
|
||||
from fastchat.conversation import Conversation, SeparatorStyle
|
||||
from fastchat.model.model_adapter import get_conversation_template
|
||||
_fastchat_available = True
|
||||
except ImportError:
|
||||
_fastchat_available = False
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
logger = init_logger(__name__)
|
||||
served_model = None
|
||||
app = fastapi.FastAPI()
|
||||
engine = None
|
||||
response_role = None
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
parser.add_argument("--host", type=str, default=None, help="host name")
|
||||
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||
parser.add_argument("--allow-credentials",
|
||||
action="store_true",
|
||||
help="allow credentials")
|
||||
parser.add_argument("--allowed-origins",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed origins")
|
||||
parser.add_argument("--allowed-methods",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed methods")
|
||||
parser.add_argument("--allowed-headers",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed headers")
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. If not "
|
||||
"specified, the model name will be the same as "
|
||||
"the huggingface name.")
|
||||
parser.add_argument("--chat-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file path to the chat template, "
|
||||
"or the template in single-line form "
|
||||
"for the specified model")
|
||||
parser.add_argument("--response-role",
|
||||
type=str,
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=true`.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_error_response(status_code: HTTPStatus,
|
||||
|
@ -54,6 +89,25 @@ def create_error_response(status_code: HTTPStatus,
|
|||
status_code=status_code.value)
|
||||
|
||||
|
||||
def load_chat_template(args, tokenizer):
|
||||
if args.chat_template is not None:
|
||||
try:
|
||||
with open(args.chat_template, "r") as f:
|
||||
chat_template = f.read()
|
||||
except OSError:
|
||||
# If opening a file fails, set chat template to be args to
|
||||
# ensure we decode so our escape are interpreted correctly
|
||||
chat_template = codecs.decode(args.chat_template, "unicode_escape")
|
||||
|
||||
tokenizer.chat_template = chat_template
|
||||
logger.info(
|
||||
f"Using supplied chat template:\n{tokenizer.chat_template}")
|
||||
elif tokenizer.chat_template is not None:
|
||||
logger.info(f"Using default chat template:\n{tokenizer.chat_template}")
|
||||
else:
|
||||
logger.warning("No chat template provided. Chat API will not work.")
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_, exc):
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
|
||||
|
@ -69,53 +123,6 @@ async def check_model(request) -> Optional[JSONResponse]:
|
|||
return ret
|
||||
|
||||
|
||||
async def get_gen_prompt(request) -> str:
|
||||
if not _fastchat_available:
|
||||
raise ModuleNotFoundError(
|
||||
"fastchat is not installed. Please install fastchat to use "
|
||||
"the chat completion and conversation APIs: `$ pip install fschat`"
|
||||
)
|
||||
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
|
||||
raise ImportError(
|
||||
f"fastchat version is low. Current version: {fastchat.__version__} "
|
||||
"Please upgrade fastchat to use: `$ pip install -U fschat`")
|
||||
|
||||
conv = get_conversation_template(request.model)
|
||||
conv = Conversation(
|
||||
name=conv.name,
|
||||
system_template=conv.system_template,
|
||||
system_message=conv.system_message,
|
||||
roles=conv.roles,
|
||||
messages=list(conv.messages), # prevent in-place modification
|
||||
offset=conv.offset,
|
||||
sep_style=SeparatorStyle(conv.sep_style),
|
||||
sep=conv.sep,
|
||||
sep2=conv.sep2,
|
||||
stop_str=conv.stop_str,
|
||||
stop_token_ids=conv.stop_token_ids,
|
||||
)
|
||||
|
||||
if isinstance(request.messages, str):
|
||||
prompt = request.messages
|
||||
else:
|
||||
for message in request.messages:
|
||||
msg_role = message["role"]
|
||||
if msg_role == "system":
|
||||
conv.system_message = message["content"]
|
||||
elif msg_role == "user":
|
||||
conv.append_message(conv.roles[0], message["content"])
|
||||
elif msg_role == "assistant":
|
||||
conv.append_message(conv.roles[1], message["content"])
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {msg_role}")
|
||||
|
||||
# Add a blank message for the assistant.
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def check_length(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
|
@ -207,7 +214,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||
- function_call (Users should implement this by themselves)
|
||||
- logit_bias (to be supported by vLLM engine)
|
||||
"""
|
||||
|
||||
error_check_ret = await check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
@ -217,7 +223,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
prompt = await get_gen_prompt(request)
|
||||
try:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
conversation=request.messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=request.add_generation_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in applying chat template from request: {str(e)}")
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
token_ids, error_check_ret = await check_length(request, prompt=prompt)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
@ -225,6 +239,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||
model_name = request.model
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.monotonic())
|
||||
chunk_object_type = "chat.completion.chunk"
|
||||
try:
|
||||
spaces_between_special_tokens = request.spaces_between_special_tokens
|
||||
sampling_params = SamplingParams(
|
||||
|
@ -249,128 +264,162 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||
result_generator = engine.generate(prompt, sampling_params, request_id,
|
||||
token_ids)
|
||||
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
finish_reason: Optional[str] = None,
|
||||
usage: Optional[UsageInfo] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=text),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
response = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=[choice_data],
|
||||
)
|
||||
if usage is not None:
|
||||
response.usage = usage
|
||||
# exclude unset to leave details out of each sse
|
||||
response_json = response.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
return response_json
|
||||
def get_role() -> str:
|
||||
if request.add_generation_prompt:
|
||||
return response_role
|
||||
else:
|
||||
return request.messages[-1]["role"]
|
||||
|
||||
async def completion_stream_generator() -> AsyncGenerator[str, None]:
|
||||
# First chunk with role
|
||||
# Send first response for each request.n (index) with the role
|
||||
role = get_role()
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None,
|
||||
)
|
||||
index=i, delta=DeltaMessage(role=role), finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response to echo the input portion of the last message
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
if last_msg_content:
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=last_msg_content),
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response for each token for each request.n (index)
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
finish_reason_sent = [False] * request.n
|
||||
async for res in result_generator:
|
||||
res: RequestOutput
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
completion_tokens = len(output.token_ids)
|
||||
previous_num_tokens[i] = completion_tokens
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text=delta_text,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
if output.finish_reason is not None:
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
continue
|
||||
|
||||
if output.finish_reason is None:
|
||||
# Send token-by-token response for each request.n
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
completion_tokens = len(output.token_ids)
|
||||
previous_num_tokens[i] = completion_tokens
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
# Send the finish response for each request.n only once
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
response_json = create_stream_response_json(
|
||||
index=i,
|
||||
text="",
|
||||
finish_reason=output.finish_reason,
|
||||
usage=final_usage,
|
||||
)
|
||||
yield f"data: {response_json}\n\n"
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=[], finish_reason=output.finish_reason)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
if final_usage is not None:
|
||||
chunk.usage = final_usage
|
||||
data = chunk.json(exclude_unset=True,
|
||||
exclude_none=True,
|
||||
ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
finish_reason_sent[i] = True
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def completion_full_generator():
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
|
||||
choices = []
|
||||
role = get_role()
|
||||
for output in final_res.outputs:
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role=role, content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
|
||||
for choice in choices:
|
||||
full_message = last_msg_content + choice.message.content
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
# Streaming response
|
||||
if request.stream:
|
||||
return StreamingResponse(completion_stream_generator(),
|
||||
media_type="text/event-stream")
|
||||
|
||||
# Non-streaming response
|
||||
final_res: RequestOutput = None
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await engine.abort(request_id)
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
choices = []
|
||||
for output in final_res.outputs:
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role="assistant", content=output.text),
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(fake_stream_generator(),
|
||||
media_type="text/event-stream")
|
||||
|
||||
return response
|
||||
else:
|
||||
return await completion_full_generator()
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
|
@ -642,34 +691,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
parser.add_argument("--host", type=str, default=None, help="host name")
|
||||
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||
parser.add_argument("--allow-credentials",
|
||||
action="store_true",
|
||||
help="allow credentials")
|
||||
parser.add_argument("--allowed-origins",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed origins")
|
||||
parser.add_argument("--allowed-methods",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed methods")
|
||||
parser.add_argument("--allowed-headers",
|
||||
type=json.loads,
|
||||
default=["*"],
|
||||
help="allowed headers")
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. If not "
|
||||
"specified, the model name will be the same as "
|
||||
"the huggingface name.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
@ -686,6 +708,8 @@ if __name__ == "__main__":
|
|||
else:
|
||||
served_model = args.model
|
||||
|
||||
response_role = args.response_role
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
|
@ -696,6 +720,7 @@ if __name__ == "__main__":
|
|||
engine_model_config.tokenizer,
|
||||
tokenizer_mode=engine_model_config.tokenizer_mode,
|
||||
trust_remote_code=engine_model_config.trust_remote_code)
|
||||
load_chat_template(args, tokenizer)
|
||||
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
|
|
|
@ -73,6 +73,8 @@ class ChatCompletionRequest(BaseModel):
|
|||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
add_generation_prompt: Optional[bool] = True
|
||||
echo: Optional[bool] = False
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
|
|
Loading…
Reference in New Issue