[CI/BUILD] enable intel queue for longer CPU tests (#4113)

Authored by Yuan on 2024-06-04 01:39:50 +08:00; committed by GitHub
parent cbb2f59cc8
commit cafb8e06c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 138 additions and 90 deletions

View File

@@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; }
 trap remove_docker_container EXIT
 remove_docker_container
 
-# Run the image and launch offline inference
-docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py
+# Run the image
+docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
+
+# offline inference
+docker exec cpu-test bash -c "python3 examples/offline_inference.py"
+
+# Run basic model test
+docker exec cpu-test bash -c "cd tests;
+  pip install pytest Pillow protobuf
+  bash ../.buildkite/download-images.sh
+  cd ../
+  pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
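
For context on what the container's smoke test exercises, here is a minimal sketch of vLLM's offline inference API along the lines of examples/offline_inference.py (the model name and sampling values are illustrative placeholders, not taken from this commit):

# Minimal offline-inference sketch; it runs on the CPU backend when vLLM is
# built with VLLM_TARGET_DEVICE=cpu, as in the Dockerfile below.
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="facebook/opt-125m")  # placeholder model for a smoke test
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    # Each RequestOutput carries the original prompt and its completions.
    print(f"{output.prompt!r} -> {output.outputs[0].text!r}")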

View File

@@ -40,6 +40,8 @@ steps:
 - label: "Intel Test"
   depends_on: ~
+  agents:
+    queue: intel
   command: bash .buildkite/run-cpu-test.sh
 
 {% for step in steps %}

View File

@@ -1,6 +1,6 @@
 # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
 
-FROM ubuntu:22.04
+FROM ubuntu:22.04 AS cpu-test-1
 
 RUN apt-get update -y \
     && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
@@ -9,6 +9,8 @@ RUN apt-get update -y \
 RUN pip install --upgrade pip \
     && pip install wheel packaging ninja setuptools>=49.4.0 numpy
 
+FROM cpu-test-1 AS build
+
 COPY ./ /workspace/vllm
 
 WORKDIR /workspace/vllm
@@ -19,4 +21,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
 
 WORKDIR /workspace/
 
+RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
+
 CMD ["/bin/bash"]

View File

@@ -21,7 +21,57 @@ void rotary_embedding_impl(
   constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
   const int embed_dim = rot_dim / 2;
-  TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
+  bool flag = (embed_dim % VEC_ELEM_NUM == 0);
+  const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
+
+  auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
+                          scalar_t* qk) {
+    int j = 0;
+    for (; j < loop_upper; j += VEC_ELEM_NUM) {
+      const int rot_offset = j;
+      const int x_index = rot_offset;
+      const int y_index = embed_dim + rot_offset;
+
+      const int64_t out_x = token_head + x_index;
+      const int64_t out_y = token_head + y_index;
+
+      const scalar_vec_t cos(cache_ptr + x_index);
+      const scalar_vec_t sin(cache_ptr + y_index);
+
+      const scalar_vec_t q_x(qk + out_x);
+      const scalar_vec_t q_y(qk + out_y);
+
+      vec_op::FP32Vec8 fp32_cos(cos);
+      vec_op::FP32Vec8 fp32_sin(sin);
+
+      vec_op::FP32Vec8 fp32_q_x(q_x);
+      vec_op::FP32Vec8 fp32_q_y(q_y);
+
+      auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+      scalar_vec_t(out1).save(qk + out_x);
+
+      auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      scalar_vec_t(out2).save(qk + out_y);
+    }
+
+    if (!flag) {
+      for (; j < embed_dim; ++j) {
+        const int x_index = j;
+        const int y_index = embed_dim + j;
+
+        const int64_t out_x = token_head + x_index;
+        const int64_t out_y = token_head + y_index;
+
+        const float fp32_cos = cache_ptr[x_index];
+        const float fp32_sin = cache_ptr[y_index];
+
+        const float fp32_q_x = qk[out_x];
+        const float fp32_q_y = qk[out_y];
+
+        qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
+        qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
+      }
+    }
+  };
+
 #pragma omp parallel for
   for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
@@ -32,62 +82,13 @@ void rotary_embedding_impl(
       const int head_idx = i;
       const int64_t token_head =
           token_idx * query_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
-
-        const int64_t out_x = token_head + x_index;
-        const int64_t out_y = token_head + y_index;
-
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
-
-        const scalar_vec_t q_x(query + out_x);
-        const scalar_vec_t q_y(query + out_y);
-
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
-
-        vec_op::FP32Vec8 fp32_q_x(q_x);
-        vec_op::FP32Vec8 fp32_q_y(q_y);
-
-        auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
-        scalar_vec_t(out1).save(query + out_x);
-
-        auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
-        scalar_vec_t(out2).save(query + out_y);
-      }
+      compute_loop(token_head, cache_ptr, query);
     }
 
     for (int i = 0; i < num_kv_heads; ++i) {
       const int head_idx = i;
      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-      for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) {
-        const int rot_offset = j;
-        const int x_index = rot_offset;
-        const int y_index = embed_dim + rot_offset;
-
-        const int64_t out_x = token_head + x_index;
-        const int64_t out_y = token_head + y_index;
-
-        const scalar_vec_t cos(cache_ptr + x_index);
-        const scalar_vec_t sin(cache_ptr + y_index);
-
-        const scalar_vec_t k_x(key + out_x);
-        const scalar_vec_t k_y(key + out_y);
-
-        vec_op::FP32Vec8 fp32_cos(cos);
-        vec_op::FP32Vec8 fp32_sin(sin);
-
-        vec_op::FP32Vec8 fp32_k_x(k_x);
-        vec_op::FP32Vec8 fp32_k_y(k_y);
-
-        auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
-        scalar_vec_t(out1).save(key + out_x);
-
-        auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
-        scalar_vec_t(out2).save(key + out_y);
-      }
+      compute_loop(token_head, cache_ptr, key);
     }
   }
 }
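
The refactor above replaces the hard TORCH_CHECK with a vector-width main loop plus a scalar tail, so rotary embedding no longer requires embed_dim to be a multiple of VEC_ELEM_NUM. A rough NumPy sketch of that split (illustration only, not the kernel; the kernel's loop_upper drops to embed_dim - VEC_ELEM_NUM in the non-divisible case, while this sketch simply stops at the last full chunk):

# NumPy illustration of compute_loop: rotate the [x_half | y_half] layout of a
# head using a vectorized main loop and a scalar tail for any remainder.
import numpy as np

def rotate_neox_cpu(qk: np.ndarray, cos: np.ndarray, sin: np.ndarray,
                    vec_width: int = 8) -> np.ndarray:
    embed_dim = cos.shape[0]                       # qk has length 2 * embed_dim
    out = qk.copy()
    main = (embed_dim // vec_width) * vec_width    # end of the last full vector chunk
    for j in range(0, main, vec_width):            # vectorized main loop
        xs = slice(j, j + vec_width)
        ys = slice(embed_dim + j, embed_dim + j + vec_width)
        x, y = qk[xs], qk[ys]
        out[xs] = x * cos[xs] - y * sin[xs]
        out[ys] = y * cos[xs] + x * sin[xs]
    for j in range(main, embed_dim):               # scalar tail (the !flag path)
        x, y = qk[j], qk[embed_dim + j]
        out[j] = x * cos[j] - y * sin[j]
        out[embed_dim + j] = y * cos[j] + x * sin[j]
    return out

With, say, embed_dim = 12 and vec_width = 8, the main loop covers j = 0..7 and the tail handles j = 8..11, which is exactly the shape the removed TORCH_CHECK used to reject.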

View File

@@ -18,6 +18,7 @@ from vllm.logger import init_logger
 from vllm.multimodal import MultiModalData
 from vllm.multimodal.image import ImageFeatureData, ImagePixelData
 from vllm.sequence import SampleLogprobs
+from vllm.utils import is_cpu
 
 logger = init_logger(__name__)
@@ -58,7 +59,8 @@ def cleanup():
     with contextlib.suppress(AssertionError):
         torch.distributed.destroy_process_group()
     gc.collect()
-    torch.cuda.empty_cache()
+    if not is_cpu():
+        torch.cuda.empty_cache()
 
 @pytest.fixture()
@@ -151,6 +153,12 @@ _EMBEDDING_MODELS = [
 class HfRunner:
 
+    def wrap_device(self, input: any):
+        if not is_cpu():
+            return input.to("cuda")
+        else:
+            return input.to("cpu")
+
     def __init__(
         self,
         model_name: str,
@@ -164,16 +172,18 @@ class HfRunner:
         if model_name in _EMBEDDING_MODELS:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
-            self.model = SentenceTransformer(
-                model_name,
-                device="cpu",
-            ).to(dtype=torch_dtype).cuda()
+            self.model = self.wrap_device(
+                SentenceTransformer(
+                    model_name,
+                    device="cpu",
+                ).to(dtype=torch_dtype))
         else:
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch_dtype,
-                trust_remote_code=True,
-            ).cuda()
+            self.model = self.wrap_device(
+                AutoModelForCausalLM.from_pretrained(
+                    model_name,
+                    torch_dtype=torch_dtype,
+                    trust_remote_code=True,
+                ))
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name,
@@ -214,7 +224,7 @@ class HfRunner:
             inputs = self.processor(**processor_kwargs)
 
             output_ids = self.model.generate(
-                **inputs.to("cuda"),
+                **self.wrap_device(inputs),
                 use_cache=True,
                 **kwargs,
             )
@@ -271,7 +281,7 @@ class HfRunner:
         for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
             output = self.model.generate(
-                input_ids.cuda(),
+                self.wrap_device(input_ids),
                 use_cache=True,
                 do_sample=False,
                 max_new_tokens=max_tokens,
@@ -306,7 +316,7 @@ class HfRunner:
         for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
             output = self.model.generate(
-                input_ids.cuda(),
+                self.wrap_device(input_ids),
                 use_cache=True,
                 do_sample=False,
                 max_new_tokens=max_tokens,
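
The new wrap_device helper centralizes device placement so the same HfRunner code path runs on both CUDA agents and the new CPU queue. A standalone sketch of that pattern, using torch.cuda.is_available() as a stand-in for vllm.utils.is_cpu() (which actually reflects the vLLM build target rather than runtime GPU presence) and a placeholder model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def wrap_device(obj):
    # Stand-in for HfRunner.wrap_device: move to CUDA only when a GPU exists,
    # otherwise keep the model/tensor on CPU.
    return obj.to("cuda") if torch.cuda.is_available() else obj.to("cpu")

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder model
model = wrap_device(AutoModelForCausalLM.from_pretrained("facebook/opt-125m"))
input_ids = wrap_device(tokenizer("Hello world", return_tensors="pt").input_ids)
output = model.generate(input_ids, max_new_tokens=8)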

View File

@@ -8,10 +8,13 @@ import torch
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-aqlm_not_supported = (capability <
-                      QUANTIZATION_METHODS["aqlm"].get_min_capability())
+aqlm_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    aqlm_not_supported = (capability <
+                          QUANTIZATION_METHODS["aqlm"].get_min_capability())
 
 # In this test we hardcode prompts and generations for the model so we don't
 # need to require the AQLM package as a dependency
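
The same guard is repeated for fp8, GPTQ-Marlin, and Marlin below. A small helper (hypothetical, not part of this commit) captures the shared pattern of defaulting to "not supported" on CPU-only agents, where torch.cuda.get_device_capability() would otherwise raise at import time:

import torch
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

def quant_method_not_supported(name: str) -> bool:
    # Default to "unsupported" so GPU-only quantization tests are skipped on the
    # CPU queue; only query the device capability when a GPU is present.
    if not torch.cuda.is_available():
        return True
    major, minor = torch.cuda.get_device_capability()
    return (major * 10 + minor) < QUANTIZATION_METHODS[name].get_min_capability()

aqlm_not_supported = quant_method_not_supported("aqlm")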

View File

@@ -5,6 +5,7 @@ This tests bigger models and use half precision.
 Run `pytest tests/models/test_big_models.py`.
 """
 import pytest
+import torch
 
 MODELS = [
     "meta-llama/Llama-2-7b-hf",
@@ -16,9 +17,14 @@ MODELS = [
     # "Qwen/Qwen1.5-0.5B"  # Broken,
 ]
 
+#TODO: remove this after CPU float16 support ready
+target_dtype = "float"
+if torch.cuda.is_available():
+    target_dtype = "half"
+
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [32])
 def test_models(
     hf_runner,
@@ -46,7 +52,7 @@ def test_models(
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("dtype", [target_dtype])
 def test_model_print(
     vllm_runner,
     model: str,

View File

@@ -67,10 +67,13 @@ EXPECTED_STRS_MAP = {
     },
 }
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-fp8_not_supported = (capability <
-                     QUANTIZATION_METHODS["fp8"].get_min_capability())
+fp8_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    fp8_not_supported = (capability <
+                         QUANTIZATION_METHODS["fp8"].get_min_capability())
 
 
 @pytest.mark.skipif(fp8_not_supported,

View File

@@ -22,10 +22,13 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 MAX_MODEL_LEN = 1024
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-gptq_marlin_not_supported = (
-    capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
+gptq_marlin_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    gptq_marlin_not_supported = (
+        capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
 
 MODELS = [
     # act_order==False, group_size=channelwise

View File

@@ -14,10 +14,13 @@ import torch
 from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (capability <
-                        QUANTIZATION_METHODS["marlin"].get_min_capability())
+marlin_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    marlin_not_supported = (
+        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
 
 
 @dataclass

View File

@@ -19,10 +19,13 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
 from .utils import check_logprobs_close
 
-capability = torch.cuda.get_device_capability()
-capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (capability <
-                        QUANTIZATION_METHODS["marlin"].get_min_capability())
+marlin_not_supported = True
+
+if torch.cuda.is_available():
+    capability = torch.cuda.get_device_capability()
+    capability = capability[0] * 10 + capability[1]
+    marlin_not_supported = (
+        capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
 
 
 @dataclass