mirror of https://github.com/vllm-project/vllm
[ci][distributed] try to fix pp test (#7054)
This commit is contained in:
parent
3bb4b1e4cd
commit
252357793d
|
@ -9,7 +9,7 @@ import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..utils import compare_two_settings
|
from ..utils import compare_two_settings, fork_new_process_for_each_test
|
||||||
|
|
||||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||||
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
|
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
|
||||||
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
|
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
|
||||||
])
|
])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
|
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
|
||||||
DIST_BACKEND):
|
DIST_BACKEND):
|
||||||
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
|
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
|
||||||
|
@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
|
||||||
"FLASH_ATTN",
|
"FLASH_ATTN",
|
||||||
"FLASHINFER",
|
"FLASHINFER",
|
||||||
])
|
])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
|
def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
|
||||||
cudagraph_args = [
|
cudagraph_args = [
|
||||||
# use half precision for speed and memory savings in CI environment
|
# use half precision for speed and memory savings in CI environment
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
|
import signal
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
@ -336,3 +338,40 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
|
||||||
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
|
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
|
||||||
|
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
def fork_new_process_for_each_test(f):
    """Decorate a test so that its body runs in a forked child process.

    Each call of the returned wrapper forks; the child runs ``f`` and
    terminates with ``os._exit`` (status 0 on success or skip, 1 on any
    other exception), while the parent waits for it, sends SIGTERM to the
    child's process group, and asserts that the child's wait status is 0.

    Args:
        f: The test function to run in the child process.

    Returns:
        A wrapper with the same metadata as ``f`` (via ``functools.wraps``).

    Raises:
        AssertionError: In the parent, when the child reported a non-zero
            wait status.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped
        pid = os.fork()
        if pid == 0:
            # Child process: never return into the caller's stack; every
            # path below ends in os._exit so the forked copy of the test
            # runner cannot keep going.
            try:
                f(*args, **kwargs)
            except Skipped as e:
                # convert Skipped to exit code 0
                print(str(e))
                os._exit(0)
            except BaseException:
                # Catch BaseException rather than Exception: pytest outcome
                # exceptions such as the one raised by pytest.fail(), as
                # well as KeyboardInterrupt and SystemExit, do not derive
                # from Exception and would otherwise escape the child.
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            pgid = os.getpgid(pid)
            _pid, _exitcode = os.waitpid(pid, 0)
            # ignore SIGTERM signal itself
            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
            try:
                # kill all child processes
                os.killpg(pgid, signal.SIGTERM)
            finally:
                # restore the signal handler, even if killpg failed;
                # signal.signal returns None when the previous handler was
                # not installed from Python, and None cannot be re-installed
                if old_signal_handler is not None:
                    signal.signal(signal.SIGTERM, old_signal_handler)
            assert _exitcode == 0, (f"function {f} failed when called with"
                                    f" args {args} and kwargs {kwargs}")

    return wrapper
|
||||||
|
|
|
@ -3,7 +3,7 @@ import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ray.exceptions import ActorDiedError
|
from ray.exceptions import ActorDiedError # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# For older versions of Ray
|
# For older versions of Ray
|
||||||
from ray.exceptions import RayActorError as ActorDiedError # type: ignore
|
from ray.exceptions import RayActorError as ActorDiedError # type: ignore
|
||||||
|
|
|
@ -928,7 +928,8 @@ def error_on_invalid_device_count_status():
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
# future pytorch will fix the issue, device_count will not be cached
|
# future pytorch will fix the issue, device_count will not be cached
|
||||||
# at that time, `.cache_info().currsize` will error out
|
# at that time, `.cache_info().currsize` will error out
|
||||||
cache_entries = torch.cuda.device_count.cache_info().currsize
|
cache_entries = torch.cuda.device_count.cache_info( # type: ignore
|
||||||
|
).currsize
|
||||||
if cache_entries != 0:
|
if cache_entries != 0:
|
||||||
# the function is already called, and the result is cached
|
# the function is already called, and the result is cached
|
||||||
remembered = torch.cuda.device_count()
|
remembered = torch.cuda.device_count()
|
||||||
|
|
Loading…
Reference in New Issue