From 42e0c1df789a2079eeb219fa790cbf5678af662f Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Mon, 3 Jul 2023 14:50:56 -0700
Subject: [PATCH] [Quality] Add CI for formatting (#343)

---
 .github/workflows/pylint.yml          | 31 +++++++++++++
 .github/workflows/yapf.yml            | 31 +++++++++++++
 vllm/engine/async_llm_engine.py       |  8 ++++
 vllm/engine/llm_engine.py             |  4 ++
 vllm/entrypoints/openai/api_server.py | 67 +++++++++++++++------------
 vllm/model_executor/models/bloom.py   |  3 +-
 6 files changed, 113 insertions(+), 31 deletions(-)
 create mode 100644 .github/workflows/pylint.yml
 create mode 100644 .github/workflows/yapf.yml

diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
new file mode 100644
index 0000000000..5e096f3c6e
--- /dev/null
+++ b/.github/workflows/pylint.yml
@@ -0,0 +1,31 @@
+name: pylint
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  pylint:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pylint==2.8.2
+    - name: Analysing the code with pylint
+      run: |
+        pylint vllm
diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml
new file mode 100644
index 0000000000..590e27597e
--- /dev/null
+++ b/.github/workflows/yapf.yml
@@ -0,0 +1,31 @@
+name: yapf
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+jobs:
+  yapf:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install yapf==0.32.0
+        pip install toml==0.10.2
+    - name: Running yapf
+      run: |
+        yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 5576f90f85..4f3af70576 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -2,6 +2,7 @@ import asyncio
 import time
 from typing import Dict, List, Optional
 
+from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.ray_utils import initialize_cluster, ray
@@ -206,6 +207,13 @@ class AsyncLLMEngine:
             self.is_engine_running = False
             self.kicking_request_id = None
 
+    async def get_model_config(self) -> ModelConfig:
+        """Get the model configuration of the vLLM engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_model_config.remote()
+        else:
+            return self.engine.get_model_config()
+
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 3afdfd6ce2..65183b5889 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -210,6 +210,10 @@ class LLMEngine:
         """
         self.scheduler.abort_seq_group(request_id)
 
+    def get_model_config(self) -> ModelConfig:
+        """Gets the model configuration."""
+        return self.model_config
+
     def get_num_unfinished_requests(self) -> int:
         """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 9670765524..75bf07e9fe 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -2,16 +2,19 @@
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
 
 import argparse
+import asyncio
 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional, Union, Any
+from typing import AsyncGenerator, Dict, List, Optional
 
 import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
+from fastchat.conversation import (Conversation, SeparatorStyle,
+                                   get_conv_template)
 import uvicorn
 
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -19,11 +22,10 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
     CompletionResponseStreamChoice, CompletionStreamResponse,
-    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
-    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
-    ModelCard, ModelList, ModelPermission, UsageInfo)
-from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -95,15 +97,15 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
 
-async def check_length(request, prompt, engine):
-    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
-        context_len = engine.engine.model_config.hf_config.max_sequence_length
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
-    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
-        context_len = engine.engine.model_config.hf_config.max_position_embeddings
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
+async def check_length(request, prompt, model_config):
+    if hasattr(model_config.hf_config, "max_sequence_length"):
+        context_len = model_config.hf_config.max_sequence_length
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
+    elif hasattr(model_config.hf_config, "max_position_embeddings"):
+        context_len = model_config.hf_config.max_position_embeddings
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
     else:
         context_len = 2048
 
@@ -182,7 +184,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine)
+    error_check_ret = await check_length(request, prompt, engine_model_config)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -206,15 +208,16 @@ async def create_chat_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params,
-                                       request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id)
 
     async def abort_request() -> None:
        await engine.abort(request_id)
 
-    def create_stream_response_json(index: int,
-                                    text: str,
-                                    finish_reason: Optional[str] = None) -> str:
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        finish_reason: Optional[str] = None,
+    ) -> str:
         choice_data = ChatCompletionResponseStreamChoice(
             index=index,
             delta=DeltaMessage(content=text),
@@ -238,10 +241,11 @@ async def create_chat_completion(raw_request: Request):
                 delta=DeltaMessage(role="assistant"),
                 finish_reason=None,
             )
-            chunk = ChatCompletionStreamResponse(
-                id=request_id, choices=[choice_data], model=model_name
-            )
-            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            chunk = ChatCompletionStreamResponse(id=request_id,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+            data = chunk.json(exclude_unset=True, ensure_ascii=False)
+            yield f"data: {data}\n\n"
 
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
@@ -295,8 +299,8 @@ async def create_chat_completion(raw_request: Request):
         choices.append(choice_data)
 
     num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(len(output.token_ids)
-                               for output in final_res.outputs)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,
@@ -314,9 +318,11 @@ async def create_chat_completion(raw_request: Request):
         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.
         response_json = response.json(ensure_ascii=False)
+
         async def fake_stream_generator() -> AsyncGenerator[str, None]:
             yield f"data: {response_json}\n\n"
             yield "data: [DONE]\n\n"
+
         return StreamingResponse(fake_stream_generator(),
                                  media_type="text/event-stream")
 
@@ -367,9 +373,9 @@ async def create_completion(raw_request: Request):
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
         if len(request.prompt) > 1:
-            return create_error_response(HTTPStatus.BAD_REQUEST,
-                                         "multiple prompts in a batch is not "
-                                         "currently supported")
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "multiple prompts in a batch is not currently supported")
         prompt = request.prompt[0]
     else:
         prompt = request.prompt
@@ -571,6 +577,7 @@ if __name__ == "__main__":
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine_model_config = asyncio.run(engine.get_model_config())
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index 12b17e4a4e..ffc47d01cb 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
 # Copyright 2023 The CacheFlow team.
 # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
 #
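
A minimal usage sketch of the accessor this patch introduces: resolve the ModelConfig once via AsyncLLMEngine.get_model_config() and derive the context window the same way check_length() does in the API server above. The model name and the helper function below are illustrative placeholders, not part of the patch.

    import asyncio

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine


    def context_len_from(model_config) -> int:
        # Probe the HF config in the same order as check_length() above.
        hf_config = model_config.hf_config
        if hasattr(hf_config, "max_sequence_length"):
            return hf_config.max_sequence_length
        if hasattr(hf_config, "seq_length"):
            return hf_config.seq_length
        if hasattr(hf_config, "max_position_embeddings"):
            return hf_config.max_position_embeddings
        return 2048  # same fallback as check_length()


    # Placeholder model; any locally available model would do.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    model_config = asyncio.run(engine.get_model_config())
    print(f"context length: {context_len_from(model_config)}")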