mirror of https://github.com/vllm-project/vllm
[Fix] Better error message when there is OOM during cache initialization (#203)
commit 1d24ccb96c
parent 14f0b39cda
@ -127,6 +127,12 @@ class LLMEngine:
         # FIXME(woosuk): Change to debug log.
         logger.info(f'# GPU blocks: {num_gpu_blocks}, '
                     f'# CPU blocks: {num_cpu_blocks}')
+
+        if num_gpu_blocks <= 0 or num_cpu_blocks <= 0:
+            raise ValueError("No available memory for the cache blocks. "
+                             "Try increasing `gpu_memory_utilization` when "
+                             "initializing the engine.")
+
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
@ -53,6 +53,7 @@ class RequestOutput:
         prompt: The prompt string of the request.
         prompt_token_ids: The token IDs of the prompt.
         outputs: The output sequences of the request.
+        finished: Whether the whole request is finished.
     """
     def __init__(
         self,
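
To see how the new guard surfaces to users, here is a minimal sketch (not part of the commit) using vLLM's public `LLM` entry point, which forwards `gpu_memory_utilization` to the engine. The model name and the deliberately tiny utilization value are illustrative only:

    from vllm import LLM

    # Starve the KV cache on purpose (illustrative value; the default is far
    # larger), so cache profiling computes num_gpu_blocks <= 0.
    try:
        llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.01)
    except ValueError as e:
        # With this commit the failure is explicit instead of a later crash:
        # "No available memory for the cache blocks. Try increasing
        #  `gpu_memory_utilization` when initializing the engine."
        print(e)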