mirror of https://github.com/vllm-project/vllm
[ci][distributed] try to fix pp test (#7054)
This commit is contained in:
parent
3bb4b1e4cd
commit
252357793d
|
@ -9,7 +9,7 @@ import os
|
|||
|
||||
import pytest
|
||||
|
||||
from ..utils import compare_two_settings
|
||||
from ..utils import compare_two_settings, fork_new_process_for_each_test
|
||||
|
||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
|
||||
|
@ -28,6 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
|||
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
|
||||
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
|
||||
])
|
||||
@fork_new_process_for_each_test
|
||||
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
|
||||
DIST_BACKEND):
|
||||
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
|
||||
|
@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
|
|||
"FLASH_ATTN",
|
||||
"FLASHINFER",
|
||||
])
|
||||
@fork_new_process_for_each_test
|
||||
def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
|
||||
cudagraph_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import functools
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
@ -336,3 +338,40 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
|
|||
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def fork_new_process_for_each_test(f):
    """Decorator that runs the wrapped test function in a forked child process.

    Isolating each test in its own process (and its own process group)
    prevents leaked resources or global state from one test from affecting
    the next. The parent waits for the child, then SIGTERMs the whole
    process group to reap any grandchildren the test spawned, and asserts
    on the child's exit code.

    Args:
        f: the test function to wrap.

    Returns:
        A wrapper with the same signature; it raises ``AssertionError``
        if the forked child exits with a non-zero status.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        # Imported lazily so this module can be imported without pytest.
        from _pytest.outcomes import Skipped
        pid = os.fork()
        if pid == 0:
            # Child: run the test and translate the outcome to an exit code.
            try:
                f(*args, **kwargs)
            except Skipped as e:
                # convert Skipped to exit code 0
                print(str(e))
                os._exit(0)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            # Parent: wait for the child, then clean up its process group.
            pgid = os.getpgid(pid)
            _pid, _exitcode = os.waitpid(pid, 0)
            # ignore SIGTERM signal itself
            # (fixed typo: was `old_singla_handler`)
            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
            # kill all child processes
            os.killpg(pgid, signal.SIGTERM)
            # restore the signal handler
            signal.signal(signal.SIGTERM, old_signal_handler)
            assert _exitcode == 0, (f"function {f} failed when called with"
                                    f" args {args} and kwargs {kwargs}")

    return wrapper
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
from typing import List, Optional
|
||||
|
||||
try:
|
||||
from ray.exceptions import ActorDiedError
|
||||
from ray.exceptions import ActorDiedError # type: ignore
|
||||
except ImportError:
|
||||
# For older versions of Ray
|
||||
from ray.exceptions import RayActorError as ActorDiedError # type: ignore
|
||||
|
|
|
@ -928,7 +928,8 @@ def error_on_invalid_device_count_status():
|
|||
with contextlib.suppress(Exception):
|
||||
# future pytorch will fix the issue, device_count will not be cached
|
||||
# at that time, `.cache_info().currsize` will error out
|
||||
cache_entries = torch.cuda.device_count.cache_info().currsize
|
||||
cache_entries = torch.cuda.device_count.cache_info( # type: ignore
|
||||
).currsize
|
||||
if cache_entries != 0:
|
||||
# the function is already called, and the result is cached
|
||||
remembered = torch.cuda.device_count()
|
||||
|
|
Loading…
Reference in New Issue