[ci][distributed] try to fix pp test (#7054)

youkaichao authored on 2024-08-01 22:03:12 -07:00 (committed by GitHub)
parent 3bb4b1e4cd
commit 252357793d
4 changed files with 45 additions and 3 deletions

tests/distributed/test_pipeline_parallel.py

@@ -9,7 +9,7 @@ import os
 import pytest
 
-from ..utils import compare_two_settings
+from ..utils import compare_two_settings, fork_new_process_for_each_test
 
 VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
@@ -28,6 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
     (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
     (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
 ])
+@fork_new_process_for_each_test
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
                     DIST_BACKEND):
     if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
@@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
     "FLASH_ATTN",
     "FLASHINFER",
 ])
+@fork_new_process_for_each_test
 def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
     cudagraph_args = [
         # use half precision for speed and memory savings in CI environment
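A note on the placement: `@fork_new_process_for_each_test` sits below the `@pytest.mark.parametrize` stack, so pytest collects every case as usual, but each case body runs in a freshly forked child process and GPU/NCCL state cannot leak from one case into the next. A hedged usage sketch (the test name and parameter are illustrative, not from this commit; it assumes the decorator from the test utilities below is importable):

import os

import pytest

from tests.utils import fork_new_process_for_each_test


@pytest.mark.parametrize("tp_size", [1, 2])
@fork_new_process_for_each_test
def test_runs_in_own_process(tp_size):
    # each parametrized case executes in its own forked child
    print(f"tp_size={tp_size} running in child pid {os.getpid()}")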

tests/utils.py

@@ -1,4 +1,6 @@
+import functools
 import os
+import signal
 import subprocess
 import sys
 import time
@@ -336,3 +338,40 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
                              f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
         time.sleep(5)
+
+
+def fork_new_process_for_each_test(f):
+
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        # Make the process the leader of its own process group
+        # to avoid sending SIGTERM to the parent process
+        os.setpgrp()
+        from _pytest.outcomes import Skipped
+        pid = os.fork()
+        if pid == 0:
+            try:
+                f(*args, **kwargs)
+            except Skipped as e:
+                # convert Skipped to exit code 0
+                print(str(e))
+                os._exit(0)
+            except Exception:
+                import traceback
+                traceback.print_exc()
+                os._exit(1)
+            else:
+                os._exit(0)
+        else:
+            pgid = os.getpgid(pid)
+            _pid, _exitcode = os.waitpid(pid, 0)
+            # ignore SIGTERM signal itself
+            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
+            # kill all child processes
+            os.killpg(pgid, signal.SIGTERM)
+            # restore the signal handler
+            signal.signal(signal.SIGTERM, old_signal_handler)
+            assert _exitcode == 0, (f"function {f} failed when called with"
+                                    f" args {args} and kwargs {kwargs}")
+
+    return wrapper
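For reference, this is the fork-and-wait pattern the decorator is built on, as a minimal POSIX-only sketch (the helper name is illustrative, not from this commit): the child reports success or failure through its exit status, and `os.waitpid` returns a status word that is zero only for a clean exit with code 0, which is exactly what the assert above checks.

import os


def run_isolated(fn) -> int:
    """Run fn in a forked child and return the child's wait status."""
    pid = os.fork()
    if pid == 0:
        # child: report the outcome through the exit code
        try:
            fn()
            os._exit(0)
        except BaseException:
            os._exit(1)
    # parent: waitpid returns (pid, status); the status word is 0 only
    # when the child exited normally with exit code 0
    _, status = os.waitpid(pid, 0)
    return status


assert run_isolated(lambda: None) == 0
assert run_isolated(lambda: 1 / 0) != 0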


@@ -3,7 +3,7 @@ import os
 from typing import List, Optional
 
 try:
-    from ray.exceptions import ActorDiedError
+    from ray.exceptions import ActorDiedError  # type: ignore
 except ImportError:
     # For older versions of Ray
     from ray.exceptions import RayActorError as ActorDiedError  # type: ignore
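The `# type: ignore` additions only quiet mypy on this version-dependent import; runtime behavior is unchanged. As a hedged sketch of how such a compatibility alias is consumed (it assumes a local `ray` installation, and the actor below is hypothetical, not part of this commit):

import ray

try:
    from ray.exceptions import ActorDiedError  # newer Ray
except ImportError:
    # older Ray raises RayActorError instead
    from ray.exceptions import RayActorError as ActorDiedError


@ray.remote
class Pinger:

    def ping(self):
        return "pong"


ray.init()
actor = Pinger.remote()
ray.kill(actor)  # simulate the actor dying
try:
    ray.get(actor.ping.remote())
except ActorDiedError:
    # new Ray raises ActorDiedError, old Ray raises RayActorError;
    # the alias catches both
    print("dead actor handled via the compatibility alias")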

vllm/utils.py

@@ -928,7 +928,8 @@ def error_on_invalid_device_count_status():
     with contextlib.suppress(Exception):
         # future pytorch will fix the issue, device_count will not be cached
         # at that time, `.cache_info().currsize` will error out
-        cache_entries = torch.cuda.device_count.cache_info().currsize
+        cache_entries = torch.cuda.device_count.cache_info(  # type: ignore
+        ).currsize
     if cache_entries != 0:
         # the function is already called, and the result is cached
         remembered = torch.cuda.device_count()
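Context for the guard above: in the affected PyTorch versions `torch.cuda.device_count` is wrapped in a `functools.lru_cache`-style cache, so a nonzero `cache_info().currsize` means it was already called and may have memoized a stale device count. A minimal sketch of that detection trick (`get_device_count` is a stand-in, not vLLM or PyTorch code):

import functools


@functools.lru_cache(maxsize=None)
def get_device_count() -> int:
    return 8  # stand-in for an expensive CUDA query


assert get_device_count.cache_info().currsize == 0  # never called yet
get_device_count()
assert get_device_count.cache_info().currsize == 1  # result is now cached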