From cafb8e06c5ffa359ac7fa4b53795e6eaa1a200c7 Mon Sep 17 00:00:00 2001 From: Yuan Date: Tue, 4 Jun 2024 01:39:50 +0800 Subject: [PATCH] [CI/BUILD] enable intel queue for longer CPU tests (#4113) --- .buildkite/run-cpu-test.sh | 14 +++- .buildkite/test-template.j2 | 2 + Dockerfile.cpu | 6 +- csrc/cpu/pos_encoding.cpp | 105 ++++++++++++++-------------- tests/conftest.py | 36 ++++++---- tests/models/test_aqlm.py | 11 +-- tests/models/test_big_models.py | 10 ++- tests/models/test_fp8.py | 11 +-- tests/models/test_gptq_marlin.py | 11 +-- tests/models/test_gptq_marlin_24.py | 11 +-- tests/models/test_marlin.py | 11 +-- 11 files changed, 138 insertions(+), 90 deletions(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 414045fe16..d1200ee84d 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; } trap remove_docker_container EXIT remove_docker_container -# Run the image and launch offline inference -docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py +# Run the image +docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test + +# offline inference +docker exec cpu-test bash -c "python3 examples/offline_inference.py" + +# Run basic model test +docker exec cpu-test bash -c "cd tests; + pip install pytest Pillow protobuf + bash ../.buildkite/download-images.sh + cd ../ + pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py" diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 265833e2cc..7e986c9884 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -40,6 +40,8 @@ steps: - label: "Intel Test" depends_on: ~ + agents: + queue: intel command: bash .buildkite/run-cpu-test.sh {% for step in steps %} diff --git a/Dockerfile.cpu b/Dockerfile.cpu index aec7982421..ae23e27b41 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -1,6 +1,6 @@ # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. -FROM ubuntu:22.04 +FROM ubuntu:22.04 AS cpu-test-1 RUN apt-get update -y \ && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ @@ -9,6 +9,8 @@ RUN apt-get update -y \ RUN pip install --upgrade pip \ && pip install wheel packaging ninja setuptools>=49.4.0 numpy +FROM cpu-test-1 AS build + COPY ./ /workspace/vllm WORKDIR /workspace/vllm @@ -19,4 +21,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install WORKDIR /workspace/ +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + CMD ["/bin/bash"] diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 73bf77e46f..e8aead17ae 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -21,7 +21,57 @@ void rotary_embedding_impl( constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); const int embed_dim = rot_dim / 2; - TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); + bool flag = (embed_dim % VEC_ELEM_NUM == 0); + const int loop_upper = flag ? 
embed_dim : embed_dim - VEC_ELEM_NUM; + + auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr, + scalar_t* qk) { + int j = 0; + for (; j < loop_upper; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t q_x(qk + out_x); + const scalar_vec_t q_y(qk + out_y); + + vec_op::FP32Vec8 fp32_cos(cos); + vec_op::FP32Vec8 fp32_sin(sin); + + vec_op::FP32Vec8 fp32_q_x(q_x); + vec_op::FP32Vec8 fp32_q_y(q_y); + + auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + scalar_vec_t(out1).save(qk + out_x); + + auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + scalar_vec_t(out2).save(qk + out_y); + } + if (!flag) { + for (; j < embed_dim; ++j) { + const int x_index = j; + const int y_index = embed_dim + j; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const float fp32_cos = cache_ptr[x_index]; + const float fp32_sin = cache_ptr[y_index]; + + const float fp32_q_x = qk[out_x]; + const float fp32_q_y = qk[out_y]; + + qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + } + } + }; #pragma omp parallel for for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { @@ -32,62 +82,13 @@ void rotary_embedding_impl( const int head_idx = i; const int64_t token_head = token_idx * query_stride + head_idx * head_size; - for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { - const int rot_offset = j; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; - - const int64_t out_x = token_head + x_index; - const int64_t out_y = token_head + y_index; - - const scalar_vec_t cos(cache_ptr + x_index); - const scalar_vec_t sin(cache_ptr + y_index); - - const scalar_vec_t q_x(query + out_x); - const scalar_vec_t q_y(query + out_y); - - vec_op::FP32Vec8 fp32_cos(cos); - vec_op::FP32Vec8 fp32_sin(sin); - - vec_op::FP32Vec8 fp32_q_x(q_x); - vec_op::FP32Vec8 fp32_q_y(q_y); - - auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; - scalar_vec_t(out1).save(query + out_x); - - auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; - scalar_vec_t(out2).save(query + out_y); - } + compute_loop(token_head, cache_ptr, query); } for (int i = 0; i < num_kv_heads; ++i) { const int head_idx = i; const int64_t token_head = token_idx * key_stride + head_idx * head_size; - for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { - const int rot_offset = j; - const int x_index = rot_offset; - const int y_index = embed_dim + rot_offset; - - const int64_t out_x = token_head + x_index; - const int64_t out_y = token_head + y_index; - - const scalar_vec_t cos(cache_ptr + x_index); - const scalar_vec_t sin(cache_ptr + y_index); - - const scalar_vec_t k_x(key + out_x); - const scalar_vec_t k_y(key + out_y); - - vec_op::FP32Vec8 fp32_cos(cos); - vec_op::FP32Vec8 fp32_sin(sin); - - vec_op::FP32Vec8 fp32_k_x(k_x); - vec_op::FP32Vec8 fp32_k_y(k_y); - - auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; - scalar_vec_t(out1).save(key + out_x); - auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; - scalar_vec_t(out2).save(key + out_y); - } + compute_loop(token_head, cache_ptr, key); } } } diff --git a/tests/conftest.py b/tests/conftest.py index e749338e10..764374a779 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -18,6 +18,7 @@ from vllm.logger import init_logger from vllm.multimodal import MultiModalData from vllm.multimodal.image import ImageFeatureData, ImagePixelData from vllm.sequence import SampleLogprobs +from vllm.utils import is_cpu logger = init_logger(__name__) @@ -58,7 +59,8 @@ def cleanup(): with contextlib.suppress(AssertionError): torch.distributed.destroy_process_group() gc.collect() - torch.cuda.empty_cache() + if not is_cpu(): + torch.cuda.empty_cache() @pytest.fixture() @@ -151,6 +153,12 @@ _EMBEDDING_MODELS = [ class HfRunner: + def wrap_device(self, input: any): + if not is_cpu(): + return input.to("cuda") + else: + return input.to("cpu") + def __init__( self, model_name: str, @@ -164,16 +172,18 @@ class HfRunner: if model_name in _EMBEDDING_MODELS: # Lazy init required for AMD CI from sentence_transformers import SentenceTransformer - self.model = SentenceTransformer( - model_name, - device="cpu", - ).to(dtype=torch_dtype).cuda() + self.model = self.wrap_device( + SentenceTransformer( + model_name, + device="cpu", + ).to(dtype=torch_dtype)) else: - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() + self.model = self.wrap_device( + AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + )) self.tokenizer = AutoTokenizer.from_pretrained( model_name, @@ -214,7 +224,7 @@ class HfRunner: inputs = self.processor(**processor_kwargs) output_ids = self.model.generate( - **inputs.to("cuda"), + **self.wrap_device(inputs), use_cache=True, **kwargs, ) @@ -271,7 +281,7 @@ class HfRunner: for prompt in prompts: input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids output = self.model.generate( - input_ids.cuda(), + self.wrap_device(input_ids), use_cache=True, do_sample=False, max_new_tokens=max_tokens, @@ -306,7 +316,7 @@ class HfRunner: for prompt in prompts: input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids output = self.model.generate( - input_ids.cuda(), + self.wrap_device(input_ids), use_cache=True, do_sample=False, max_new_tokens=max_tokens, diff --git a/tests/models/test_aqlm.py b/tests/models/test_aqlm.py index a7abc011f5..85d74f7f5b 100644 --- a/tests/models/test_aqlm.py +++ b/tests/models/test_aqlm.py @@ -8,10 +8,13 @@ import torch from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -aqlm_not_supported = (capability < - QUANTIZATION_METHODS["aqlm"].get_min_capability()) +aqlm_not_supported = True + +if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + aqlm_not_supported = (capability < + QUANTIZATION_METHODS["aqlm"].get_min_capability()) # In this test we hardcode prompts and generations for the model so we don't # need to require the AQLM package as a dependency diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 10e7c64e34..ea95e6a49f 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -5,6 +5,7 @@ This tests bigger models and use half precision. Run `pytest tests/models/test_big_models.py`. 
""" import pytest +import torch MODELS = [ "meta-llama/Llama-2-7b-hf", @@ -16,9 +17,14 @@ MODELS = [ # "Qwen/Qwen1.5-0.5B" # Broken, ] +#TODO: remove this after CPU float16 support ready +target_dtype = "float" +if torch.cuda.is_available(): + target_dtype = "half" + @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [32]) def test_models( hf_runner, @@ -46,7 +52,7 @@ def test_models( @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", [target_dtype]) def test_model_print( vllm_runner, model: str, diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 0a5819ea3f..61aee0d0a6 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -67,10 +67,13 @@ EXPECTED_STRS_MAP = { }, } -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -fp8_not_supported = (capability < - QUANTIZATION_METHODS["fp8"].get_min_capability()) +fp8_not_supported = True + +if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + fp8_not_supported = (capability < + QUANTIZATION_METHODS["fp8"].get_min_capability()) @pytest.mark.skipif(fp8_not_supported, diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index 1fc0b3f239..814471b477 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -22,10 +22,13 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true" MAX_MODEL_LEN = 1024 -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -gptq_marlin_not_supported = ( - capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability()) +gptq_marlin_not_supported = True + +if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + gptq_marlin_not_supported = ( + capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability()) MODELS = [ # act_order==False, group_size=channelwise diff --git a/tests/models/test_gptq_marlin_24.py b/tests/models/test_gptq_marlin_24.py index 3e6ffb7f90..cc35ee803f 100644 --- a/tests/models/test_gptq_marlin_24.py +++ b/tests/models/test_gptq_marlin_24.py @@ -14,10 +14,13 @@ import torch from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -marlin_not_supported = (capability < - QUANTIZATION_METHODS["marlin"].get_min_capability()) +marlin_not_supported = True + +if torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + marlin_not_supported = ( + capability < QUANTIZATION_METHODS["marlin"].get_min_capability()) @dataclass diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 37c1664afe..8520b26718 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -19,10 +19,13 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from .utils import check_logprobs_close -capability = torch.cuda.get_device_capability() -capability = capability[0] * 10 + capability[1] -marlin_not_supported = (capability < - QUANTIZATION_METHODS["marlin"].get_min_capability()) +marlin_not_supported = True + +if 
torch.cuda.is_available(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + marlin_not_supported = ( + capability < QUANTIZATION_METHODS["marlin"].get_min_capability()) @dataclass
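
Note on the csrc/cpu/pos_encoding.cpp change: the per-head rotation loops for
query and key are folded into a single compute_loop lambda, and the old
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0) is dropped. When embed_dim is not a
multiple of the vector width, the vectorized loop now stops one full vector
short of embed_dim (so its last full-width load stays inside the cos/sin
halves of the cache) and a scalar loop finishes the remaining rotations. The
sketch below models only that loop-bound arithmetic in plain Python;
VEC_ELEM_NUM = 8 is an assumed example width and the partition() helper is
illustrative, not code from this patch.

    # Model of the loop bounds used by compute_loop: a vectorized main loop
    # plus a scalar tail when embed_dim is not a multiple of VEC_ELEM_NUM.
    VEC_ELEM_NUM = 8  # example width; the real value comes from scalar_vec_t


    def partition(embed_dim: int) -> tuple[list[int], list[int]]:
        flag = embed_dim % VEC_ELEM_NUM == 0
        loop_upper = embed_dim if flag else embed_dim - VEC_ELEM_NUM

        vector_starts = []  # offsets processed VEC_ELEM_NUM elements at a time
        j = 0
        while j < loop_upper:
            vector_starts.append(j)
            j += VEC_ELEM_NUM

        scalar_tail = [] if flag else list(range(j, embed_dim))
        return vector_starts, scalar_tail


    if __name__ == "__main__":
        for embed_dim in (4, 16, 20, 24, 100):
            starts, tail = partition(embed_dim)
            covered = [j + k for j in starts for k in range(VEC_ELEM_NUM)] + tail
            # every rotation offset is handled exactly once
            assert sorted(covered) == list(range(embed_dim)), embed_dim
            print(embed_dim, starts, tail)

For example, embed_dim = 20 gives vector starts [0, 8] plus a scalar tail
[16, 17, 18, 19]. The tail is at most VEC_ELEM_NUM - 1 iterations
(embed_dim % VEC_ELEM_NUM), so its cost is negligible next to the vectorized
body.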
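
Note on the quantization tests: tests/models/test_aqlm.py, test_fp8.py,
test_gptq_marlin.py, test_gptq_marlin_24.py and test_marlin.py all gain the
same guard: the *_not_supported flag defaults to True, and
torch.cuda.get_device_capability() is queried only when
torch.cuda.is_available(), so importing these modules on the CPU-only Intel
queue no longer fails at collection time. The sketch below shows how that
repeated check could be written as one shared helper; the function name
quantization_not_supported and its placement are hypothetical, not something
this patch introduces.

    # Hypothetical consolidation of the guard repeated in the quantization
    # tests: report "not supported" on hosts without CUDA, or when the GPU's
    # compute capability is below the method's minimum.
    import torch

    from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


    def quantization_not_supported(method: str) -> bool:
        if not torch.cuda.is_available():
            # CPU-only host (e.g. the Intel queue): these kernels need CUDA.
            return True
        major, minor = torch.cuda.get_device_capability()
        capability = major * 10 + minor
        return capability < QUANTIZATION_METHODS[method].get_min_capability()


    # Usage would mirror the existing skip markers, e.g.:
    #   @pytest.mark.skipif(quantization_not_supported("marlin"),
    #                       reason="marlin is not supported on this device")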