[Bugfix] Allow "None" or "" to be passed to CLI for string args that default to None (#4586)

Michael Goin 2024-05-03 13:32:21 -04:00 committed by GitHub
parent 3521ba4f25
commit 7e65477e5e
2 changed files with 34 additions and 25 deletions
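
For context: the `nullable_str` helper added below is used as an argparse `type=` converter, so passing the literal string "None" or an empty string on the command line resolves to Python `None` instead of being kept as a plain string. A minimal, self-contained sketch of that behavior (the standalone parser is illustrative only; in the commit the converter is wired into the vLLM engine and OpenAI server argument parsers):

import argparse


def nullable_str(val: str):
    # Treat the empty string or the literal "None" as "no value".
    if not val or val == "None":
        return None
    return val


parser = argparse.ArgumentParser()
# Illustrative flag; the commit applies the converter to --tokenizer,
# --api-key, --chat-template, and other string args that default to None.
parser.add_argument("--tokenizer", type=nullable_str, default=None)

print(parser.parse_args(["--tokenizer", "None"]).tokenizer)  # None
print(parser.parse_args(["--tokenizer", ""]).tokenizer)      # None
print(parser.parse_args(["--tokenizer", "gpt2"]).tokenizer)  # 'gpt2'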


@@ -11,6 +11,12 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
@@ -96,7 +102,7 @@ class EngineArgs:
help='Name or path of the huggingface model to use.')
parser.add_argument(
'--tokenizer',
-type=str,
+type=nullable_str,
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use.')
parser.add_argument(
@@ -105,21 +111,21 @@ class EngineArgs:
help='Skip initialization of tokenizer and detokenizer')
parser.add_argument(
'--revision',
-type=str,
+type=nullable_str,
default=None,
help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
-type=str,
+type=nullable_str,
default=None,
help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
-type=str,
+type=nullable_str,
default=None,
help='The specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
@@ -136,7 +142,7 @@ class EngineArgs:
action='store_true',
help='Trust remote code from huggingface.')
parser.add_argument('--download-dir',
-type=str,
+type=nullable_str,
default=EngineArgs.download_dir,
help='Directory to download and load the weights, '
'default to the default cache dir of '
@@ -187,7 +193,7 @@ class EngineArgs:
'supported for common inference criteria.')
parser.add_argument(
'--quantization-param-path',
-type=str,
+type=nullable_str,
default=None,
help='Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when '
@@ -304,7 +310,7 @@ class EngineArgs:
# Quantization settings.
parser.add_argument('--quantization',
'-q',
-type=str,
+type=nullable_str,
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
@@ -349,7 +355,7 @@ class EngineArgs:
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.')
parser.add_argument('--tokenizer-pool-extra-config',
-type=str,
+type=nullable_str,
default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. '
'This should be a JSON string that will be '
@@ -404,7 +410,7 @@ class EngineArgs:
# Related to Vision-language models such as llava
parser.add_argument(
'--image-input-type',
-type=str,
+type=nullable_str,
default=None,
choices=[
t.name.lower() for t in VisionLanguageConfig.ImageInputType
@@ -417,7 +423,7 @@ class EngineArgs:
help=('Input id for image token.'))
parser.add_argument(
'--image-input-shape',
-type=str,
+type=nullable_str,
default=None,
help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.'))
@@ -440,7 +446,7 @@ class EngineArgs:
parser.add_argument(
'--speculative-model',
-type=str,
+type=nullable_str,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
@@ -454,7 +460,7 @@ class EngineArgs:
parser.add_argument(
'--speculative-max-model-len',
-type=str,
+type=int,
default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
@@ -475,7 +481,7 @@ class EngineArgs:
'decoding.')
parser.add_argument('--model-loader-extra-config',
-type=str,
+type=nullable_str,
default=EngineArgs.model_loader_extra_config,
help='Extra config for model loader. '
'This will be passed to the model loader '


@@ -8,7 +8,7 @@ import argparse
import json
import ssl
-from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
@@ -25,7 +25,10 @@ class LoRAParserAction(argparse.Action):
def make_arg_parser():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--host",
type=nullable_str,
default=None,
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--uvicorn-log-level",
@@ -49,13 +52,13 @@ def make_arg_parser():
default=["*"],
help="allowed headers")
parser.add_argument("--api-key",
-type=str,
+type=nullable_str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument("--served-model-name",
nargs="+",
-type=str,
+type=nullable_str,
default=None,
help="The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any "
@@ -65,33 +68,33 @@ def make_arg_parser():
"same as the `--model` argument.")
parser.add_argument(
"--lora-modules",
-type=str,
+type=nullable_str,
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
parser.add_argument("--chat-template",
-type=str,
+type=nullable_str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
-type=str,
+type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
-type=str,
+type=nullable_str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
-type=str,
+type=nullable_str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs",
-type=str,
+type=nullable_str,
default=None,
help="The CA certificates file")
parser.add_argument(
@@ -102,12 +105,12 @@ def make_arg_parser():
)
parser.add_argument(
"--root-path",
-type=str,
+type=nullable_str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
-type=str,
+type=nullable_str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "