[Bugfix]: Fix Tensorizer test failures (#6835)

This commit is contained in:
Sanger Steel 2024-07-26 23:02:25 -04:00 committed by GitHub
parent 55712941e5
commit 969d032265
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 47 additions and 6 deletions

View File

@ -220,7 +220,6 @@ steps:
- label: Tensorizer Test
#mirror_hardwares: [amd]
soft_fail: true
fast_check: true
commands:
- apt-get install -y curl libsodium23

View File

@ -0,0 +1,45 @@
# isort: skip_file
import contextlib
import gc
import pytest
import ray
import torch
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
def cleanup():
    """Tear down distributed state and release resources after a test.

    Runs, in this order: model-parallel teardown, distributed-environment
    teardown, destruction of the torch process group, a garbage-collection
    pass, ``torch.cuda.empty_cache()``, and ``ray.shutdown()``.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # An AssertionError raised while destroying the process group is
    # deliberately ignored (presumably the "no group was ever initialized"
    # case -- TODO confirm against torch.distributed).
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass
    gc.collect()
    torch.cuda.empty_cache()
    ray.shutdown()
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Tell ``cleanup_fixture`` whether the global cleanup should run.

    The default answer is ``True``. Subdirectories may override this
    fixture to skip the global cleanup; for non-GPU unit tests that can
    give a ~10x speedup because they then do not need to initialize torch.
    """
    do_cleanup = True
    return do_cleanup
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Autouse fixture that runs ``cleanup()`` after each test when enabled.

    Args:
        should_do_global_cleanup_after_test: Value of the fixture of the
            same name; when falsy, no cleanup is performed.
    """
    yield  # The test itself runs while the fixture is suspended here.
    if not should_do_global_cleanup_after_test:
        return
    cleanup()
@pytest.fixture(autouse=True)
def tensorizer_config():
    """Return a ``TensorizerConfig`` built with ``tensorizer_uri="vllm"``."""
    return TensorizerConfig(tensorizer_uri="vllm")

View File

@ -40,7 +40,6 @@ model_ref = "facebook/opt-125m"
tensorize_model_for_testing_script = os.path.join(
os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py")
def is_curl_installed():
try:
subprocess.check_call(['curl', '--version'])
@ -63,10 +62,6 @@ def write_keyfile(keyfile_path: str):
with open(keyfile_path, 'wb') as f:
f.write(encryption_params.key)
@pytest.fixture(autouse=True)
def tensorizer_config():
    """Return a ``TensorizerConfig`` built with ``tensorizer_uri="vllm"``."""
    return TensorizerConfig(tensorizer_uri="vllm")
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
@ -105,6 +100,7 @@ def test_can_deserialize_s3(vllm_runner):
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path):
cleanup()
with vllm_runner(model_ref) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
@ -316,6 +312,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
cleanup()
model_ref = "facebook/opt-125m"
model_path = tmp_path / (model_ref + ".tensors")
config = TensorizerConfig(tensorizer_uri=str(model_path))