mirror of https://github.com/vllm-project/vllm
[Quality] Add CI for formatting (#343)
This commit is contained in:
parent
e41f06702c
commit
42e0c1df78
|
@ -0,0 +1,31 @@
|
|||
name: pylint
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
pylint:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pylint==2.8.2
|
||||
- name: Analysing the code with pylint
|
||||
run: |
|
||||
pylint vllm
|
|
@ -0,0 +1,31 @@
|
|||
name: yapf
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
jobs:
|
||||
yapf:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install yapf==0.32.0
|
||||
pip install toml==0.10.2
|
||||
- name: Running yapf
|
||||
run: |
|
||||
yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.engine.ray_utils import initialize_cluster, ray
|
||||
|
@ -206,6 +207,13 @@ class AsyncLLMEngine:
|
|||
self.is_engine_running = False
|
||||
self.kicking_request_id = None
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
if self.engine_use_ray:
|
||||
return await self.engine.get_model_config.remote()
|
||||
else:
|
||||
return self.engine.get_model_config()
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls,
|
||||
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
|
||||
|
|
|
@ -210,6 +210,10 @@ class LLMEngine:
|
|||
"""
|
||||
self.scheduler.abort_seq_group(request_id)
|
||||
|
||||
def get_model_config(self) -> ModelConfig:
|
||||
"""Gets the model configuration."""
|
||||
return self.model_config
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
"""Gets the number of unfinished requests."""
|
||||
return self.scheduler.get_num_unfinished_seq_groups()
|
||||
|
|
|
@ -2,16 +2,19 @@
|
|||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from http import HTTPStatus
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import fastapi
|
||||
from fastapi import BackgroundTasks, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastchat.conversation import (Conversation, SeparatorStyle,
|
||||
get_conv_template)
|
||||
import uvicorn
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
|
@ -19,11 +22,10 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
|
|||
from vllm.entrypoints.openai.protocol import (
|
||||
CompletionRequest, CompletionResponse, CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice, CompletionStreamResponse,
|
||||
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
|
||||
ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
|
||||
ModelCard, ModelList, ModelPermission, UsageInfo)
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
@ -95,15 +97,15 @@ async def get_gen_prompt(request) -> str:
|
|||
return prompt
|
||||
|
||||
|
||||
async def check_length(request, prompt, engine):
|
||||
if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
|
||||
context_len = engine.engine.model_config.hf_config.max_sequence_length
|
||||
elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
|
||||
context_len = engine.engine.model_config.hf_config.seq_length
|
||||
elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
|
||||
context_len = engine.engine.model_config.hf_config.max_position_embeddings
|
||||
elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
|
||||
context_len = engine.engine.model_config.hf_config.seq_length
|
||||
async def check_length(request, prompt, model_config):
|
||||
if hasattr(model_config.hf_config, "max_sequence_length"):
|
||||
context_len = model_config.hf_config.max_sequence_length
|
||||
elif hasattr(model_config.hf_config, "seq_length"):
|
||||
context_len = model_config.hf_config.seq_length
|
||||
elif hasattr(model_config.hf_config, "max_position_embeddings"):
|
||||
context_len = model_config.hf_config.max_position_embeddings
|
||||
elif hasattr(model_config.hf_config, "seq_length"):
|
||||
context_len = model_config.hf_config.seq_length
|
||||
else:
|
||||
context_len = 2048
|
||||
|
||||
|
@ -182,7 +184,7 @@ async def create_chat_completion(raw_request: Request):
|
|||
"logit_bias is not currently supported")
|
||||
|
||||
prompt = await get_gen_prompt(request)
|
||||
error_check_ret = await check_length(request, prompt, engine)
|
||||
error_check_ret = await check_length(request, prompt, engine_model_config)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
|
@ -206,15 +208,16 @@ async def create_chat_completion(raw_request: Request):
|
|||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
result_generator = engine.generate(prompt, sampling_params,
|
||||
request_id)
|
||||
result_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
def create_stream_response_json(index: int,
|
||||
text: str,
|
||||
finish_reason: Optional[str] = None) -> str:
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
finish_reason: Optional[str] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=text),
|
||||
|
@ -238,10 +241,11 @@ async def create_chat_completion(raw_request: Request):
|
|||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id, choices=[choice_data], model=model_name
|
||||
)
|
||||
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
|
||||
chunk = ChatCompletionStreamResponse(id=request_id,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
|
@ -295,8 +299,8 @@ async def create_chat_completion(raw_request: Request):
|
|||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids)
|
||||
for output in final_res.outputs)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
|
@ -314,9 +318,11 @@ async def create_chat_completion(raw_request: Request):
|
|||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(fake_stream_generator(),
|
||||
media_type="text/event-stream")
|
||||
|
||||
|
@ -367,9 +373,9 @@ async def create_completion(raw_request: Request):
|
|||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"please provide at least one prompt")
|
||||
if len(request.prompt) > 1:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"multiple prompts in a batch is not "
|
||||
"currently supported")
|
||||
return create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"multiple prompts in a batch is not currently supported")
|
||||
prompt = request.prompt[0]
|
||||
else:
|
||||
prompt = request.prompt
|
||||
|
@ -571,6 +577,7 @@ if __name__ == "__main__":
|
|||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
tokenizer = get_tokenizer(engine_args.tokenizer,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# coding=utf-8
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
|
||||
# Copyright 2023 The CacheFlow team.
|
||||
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
|
||||
#
|
||||
|
|
Loading…
Reference in New Issue