diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 262ce8e153..be0afc6305 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -39,6 +39,16 @@ FIX #xxxx (*link existing issues this PR will resolve*)
Please add documentation to docs/source/
if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.
+Adding or changing kernels
+Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.
+
+ - Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
+ - Custom operations that return
Tensors
require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
+ - Use
torch.libary.opcheck()
to test the function registration and meta-function for any registered ops. See tests/kernels
for examples.
+ - When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
+ - If a new custom type is needed, see the following document: Custom Class Support in PT2.
+
+
Notes for Large Changes
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required
and might not go through the PR.
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index 69998b45be..1ea6d2b0f0 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -350,6 +350,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
${GPU_INCLUDE_DIRECTORIES})
+ # TODO: is torch_python_LIBRARY needed?
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
${GPU_LIBRARIES})
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 9d6b1962b8..b45da1b386 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -32,8 +32,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// PagedAttention V2.
ops.def(
"paged_attention_v2("
- " Tensor! out, Tensor exp_sums, Tensor max_logits,"
- " Tensor tmp_out, Tensor query, Tensor key_cache,"
+ " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+ " Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
@@ -122,8 +122,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Copy the cache blocks from src to dst.
cache_ops.def(
- "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
- "block_mapping) -> ()");
+ "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+ "Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks);
// Reshape the key and value tensors and cache them.
diff --git a/csrc/ops.h b/csrc/ops.h
index 45a3868395..05b89e183c 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -123,9 +123,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
+torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
+ torch::Tensor& perm, c10::SymInt size_k,
+ c10::SymInt size_n, int64_t num_bits);
+
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+ c10::SymInt size_k, c10::SymInt size_n,
+ int64_t num_bits);
+
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);
diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu
index c58216d8e0..de8d9ef2ee 100644
--- a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu
+++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu
@@ -267,3 +267,15 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
}
#endif
+
+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+ c10::SymInt size_k, c10::SymInt size_n,
+ int64_t num_bits) {
+ int const pack_factor = 32 / num_bits;
+ auto options = torch::TensorOptions()
+ .dtype(b_q_weight.dtype())
+ .device(b_q_weight.device());
+ return torch::empty_symint(
+ {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+ options);
+}
diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
index c71b1bf573..70d48de12a 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
@@ -342,3 +342,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
}
#endif
+
+torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
+ torch::Tensor& perm, c10::SymInt size_k,
+ c10::SymInt size_n, int64_t num_bits) {
+ int const pack_factor = 32 / num_bits;
+ auto options = torch::TensorOptions()
+ .dtype(b_q_weight.dtype())
+ .device(b_q_weight.device());
+ return torch::empty_symint(
+ {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+ options);
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 07b14e7a6f..57103c0936 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -36,8 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// PagedAttention V2.
ops.def(
"paged_attention_v2("
- " Tensor! out, Tensor exp_sums, Tensor max_logits,"
- " Tensor tmp_out, Tensor query, Tensor key_cache,"
+ " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+ " Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
@@ -73,7 +73,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
// prepare_inputs advance_step
- ops.def("advance_step", &advance_step);
+ ops.def(
+ "advance_step(int num_seqs, int num_queries, int block_size, "
+ "Tensor! input_tokens, Tensor sampled_token_ids, "
+ "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
+ "Tensor block_tables) -> ()");
ops.impl("advance_step", torch::kCUDA, &advance_step);
// Layernorm
@@ -110,27 +114,56 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
- ops.def("aqlm_gemm", &aqlm_gemm);
+ ops.def(
+ "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
+ "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
+ "-> Tensor");
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
// Decompression method for AQLM.
- ops.def("aqlm_dequant", &aqlm_dequant);
+ ops.def(
+ "aqlm_dequant(Tensor codes, Tensor codebooks, "
+ "int[] codebook_partition_sizes) -> Tensor");
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
// Quantized GEMM for AWQ.
- ops.def("awq_gemm", &awq_gemm);
+ ops.def(
+ "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
+ "Tensor _zeros, int split_k_iters) -> Tensor");
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
// Dequantization for AWQ.
- ops.def("awq_dequantize", &awq_dequantize);
+ ops.def(
+ "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
+ "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
+ // Note about marlin kernel 'workspace' arguments:
+ // Technically these should be mutable since they are modified by the kernel.
+ // But since they are set back to zero once the kernel is finished we can
+ // hand wave and say that they have no net effect.
+ //
+ // The reason to mark 'workspace' as immutable is so that they don't interfere
+ // with using ScalarType arguments in the ops. If they are marked as mutable,
+ // pytorch throws an assert in
+ // 'torch._higher_order_ops._register_effectful_op' that prevents these
+ // kernels from being torch.compile'd.
+ // See the following document for more info on custom types and ops that use
+ // custom types:
+ // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
+
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
- ops.def("marlin_gemm", &marlin_gemm);
+ ops.def(
+ "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+ "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
- ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
+ ops.def(
+ "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
+ "Tensor b_scales, Tensor workspace, "
+ "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+ "int size_m, int size_n, int size_k) -> Tensor");
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@@ -149,35 +182,55 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
// gptq_marlin Optimized Quantized GEMM for GPTQ.
- ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
+ ops.def(
+ "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+ "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
+ "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+ "int size_m, int size_n, int size_k, bool is_k_full, "
+ "bool has_zp, bool use_fp32_reduce) -> Tensor");
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
// gptq_marlin repack from GPTQ.
- ops.def("gptq_marlin_repack", &gptq_marlin_repack);
+ ops.def(
+ "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
+ "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
+ ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
// awq_marlin repack from AWQ.
- ops.def("awq_marlin_repack", &awq_marlin_repack);
+ ops.def(
+ "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
+ "SymInt size_n, int num_bits) -> Tensor");
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
+ ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
// Dequantization for GGML.
- ops.def("ggml_dequantize", &ggml_dequantize);
+ ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
// mmvq kernel for GGML.
- ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
+ ops.def(
+ "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
+ "-> Tensor");
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
// mmq kernel for GGML.
- ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
+ ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
- ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
+ ops.def(
+ "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+ "Tensor! workspace, int num_bits, int size_m, int size_n, "
+ "int size_k) -> Tensor");
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
// marlin_qqq_gemm for QQQ.
- ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
+ ops.def(
+ "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
+ "Tensor s_tok, Tensor s_ch, Tensor s_group, "
+ "Tensor! workspace, int size_m, int size_n, "
+ "int size_k) -> Tensor");
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -199,16 +252,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
- ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
- ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
- &cutlass_scaled_mm_supports_fp8);
+ ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
+ ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+
// Mamba selective scan kernel
ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
- "Tensor? index_, Tensor? x) -> Tensor[]");
+ "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def(
@@ -230,7 +283,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
// Quantized GEMM for GPTQ.
- ops.def("gptq_gemm", &gptq_gemm);
+ // Note: even though the C++ inferred schema is correct for this op, it seems
+ // to prevent the meta function registry.
+ ops.def(
+ "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
+ "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
+ "-> Tensor");
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
// Post processing for GPTQ.
@@ -250,8 +308,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
- "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
- "scale, Tensor? scale_ub) -> "
+ "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
+ "Tensor! scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);
@@ -288,8 +346,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Copy the cache blocks from src to dst.
cache_ops.def(
- "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
- "block_mapping) -> ()");
+ "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+ "Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks);
// Reshape the key and value tensors and cache them.
@@ -314,8 +372,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Convert the key and value cache to fp8 data type.
cache_ops.def(
- "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
- "kv_cache_dtype) -> ()");
+ "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
+ "str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}
@@ -323,24 +381,28 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils
// Gets the specified device attribute.
- cuda_utils.def("get_device_attribute", &get_device_attribute);
- cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
+ cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
+ cuda_utils.impl("get_device_attribute", &get_device_attribute);
// Gets the maximum shared memory per block device attribute.
- cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
- &get_max_shared_memory_per_block_device_attribute);
+ cuda_utils.def(
+ "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
- torch::kCUDA,
&get_max_shared_memory_per_block_device_attribute);
}
#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
- custom_ar.def("init_custom_ar", &init_custom_ar);
+ custom_ar.def(
+ "init_custom_ar(Tensor meta, Tensor rank_data, "
+ "str[] handles, int[] offsets, int rank, "
+ "bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
- custom_ar.def("should_custom_ar", &should_custom_ar);
+ custom_ar.def(
+ "should_custom_ar(Tensor inp, int max_size, int world_size, "
+ "bool full_nvlink) -> bool");
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
@@ -352,21 +414,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
custom_ar.def("dispose", &dispose);
- custom_ar.impl("dispose", torch::kCPU, &dispose);
-
custom_ar.def("meta_size", &meta_size);
- custom_ar.impl("meta_size", torch::kCPU, &meta_size);
- custom_ar.def("register_buffer", ®ister_buffer);
+ custom_ar.def(
+ "register_buffer(int fa, Tensor t, str[] handles, "
+ "int[] offsets) -> ()");
custom_ar.impl("register_buffer", torch::kCUDA, ®ister_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
- custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
- &get_graph_buffer_ipc_meta);
-
custom_ar.def("register_graph_buffers", ®ister_graph_buffers);
- custom_ar.impl("register_graph_buffers", torch::kCPU,
- ®ister_graph_buffers);
}
#endif
diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py
index 38b0477063..ed050ce851 100644
--- a/tests/kernels/test_activation.py
+++ b/tests/kernels/test_activation.py
@@ -3,8 +3,10 @@ from typing import Type
import pytest
import torch
+from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
- NewGELU, SiluAndMul)
+ NewGELU, QuickGELU,
+ SiluAndMul)
from .allclose_default import get_default_atol, get_default_rtol
@@ -39,18 +41,28 @@ def test_act_and_mul(
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu":
layer = SiluAndMul()
+ fn = torch.ops._C.silu_and_mul
elif activation == "gelu":
layer = GeluAndMul(approximate="none")
+ fn = torch.ops._C.gelu_and_mul
elif activation == "gelu_tanh":
layer = GeluAndMul(approximate="tanh")
+ fn = torch.ops._C.gelu_tanh_and_mul
out = layer(x)
ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
+ d = x.shape[-1] // 2
+ output_shape = (x.shape[:-1] + (d, ))
+ out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+ opcheck(fn, (out, x))
-@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
+
+@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
+ (NewGELU, torch.ops._C.gelu_new),
+ (QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@@ -70,10 +82,14 @@ def test_activation(
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
- layer = activation()
+ layer = activation[0]()
+ fn = activation[1]
out = layer(x)
ref_out = layer.forward_native(x)
torch.testing.assert_close(out,
ref_out,
atol=get_default_atol(out),
rtol=get_default_rtol(out))
+
+ out = torch.empty_like(x)
+ opcheck(fn, (out, x))
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 8aa2d4a53a..7995f11f19 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -6,6 +6,7 @@ import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
+from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
@@ -198,6 +199,13 @@ def test_paged_attention(
k_scale,
v_scale,
)
+
+ opcheck(torch.ops._C.paged_attention_v1,
+ (output, query, key_cache, value_cache, num_kv_heads, scale,
+ block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
+ kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+ cond=(head_size == HEAD_SIZES[0]))
+
elif version == "v2":
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
@@ -230,6 +238,14 @@ def test_paged_attention(
k_scale,
v_scale,
)
+
+ opcheck(torch.ops._C.paged_attention_v2,
+ (output, exp_sums, max_logits, tmp_output, query, key_cache,
+ value_cache, num_kv_heads, scale, block_tables, seq_lens,
+ block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
+ k_scale, v_scale, 0, 0, 0, 64, 0),
+ cond=(head_size == HEAD_SIZES[0]))
+
else:
raise AssertionError(f"Unknown version: {version}")
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 71d1835916..19402a337b 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -4,6 +4,7 @@ from typing import List, Tuple
import pytest
import torch
+from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@@ -87,6 +88,11 @@ def test_copy_blocks(
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device=device).view(-1, 2)
+
+ opcheck(torch.ops._C_cache_ops.copy_blocks,
+ (key_caches, value_caches, block_mapping_tensor),
+ test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+ cond=(head_size == HEAD_SIZES[0]))
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
# Run the reference implementation.
@@ -162,6 +168,10 @@ def test_reshape_and_cache(
k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel.
+ opcheck(torch.ops._C_cache_ops.reshape_and_cache,
+ (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
+ k_scale, v_scale),
+ cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale)
@@ -269,6 +279,10 @@ def test_reshape_and_cache_flash(
k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel.
+ opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
+ (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
+ k_scale, v_scale),
+ cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale)
@@ -366,6 +380,14 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
+ do_opcheck = (head_size == HEAD_SIZES[0])
+ opcheck(torch.ops._C_cache_ops.swap_blocks,
+ (src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
+ cond=do_opcheck)
+ opcheck(torch.ops._C_cache_ops.swap_blocks,
+ (src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
+ cond=do_opcheck)
+
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py
index e818651fe9..d1f0524f83 100644
--- a/tests/kernels/test_cutlass.py
+++ b/tests/kernels/test_cutlass.py
@@ -7,6 +7,7 @@ from typing import Optional, Type
import pytest
import torch
+from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
@@ -108,6 +109,9 @@ def cutlass_int8_gemm_helper(m: int,
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
+ opcheck(torch.ops._C.cutlass_scaled_mm,
+ (out, a, b, scale_a, scale_b, bias))
+
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
@@ -341,6 +345,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
+ if azp_per_token:
+ opcheck(torch.ops._C.cutlass_scaled_mm_azp,
+ (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
+ func_bias))
+ else:
+ opcheck(torch.ops._C.cutlass_scaled_mm_azp,
+ (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
+ func_bias))
+
# Test working with a subset of A and B
def test_cutlass_subset():
diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py
index 7376dcaf60..a82ecb0264 100644
--- a/tests/kernels/test_int8_quant.py
+++ b/tests/kernels/test_int8_quant.py
@@ -2,6 +2,7 @@ import pytest
import torch
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
+from tests.kernels.utils import opcheck
from vllm._custom_ops import scaled_int8_quant
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -12,6 +13,16 @@ SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
+def opcheck_int8_quant(output, input, scale=None):
+ if scale is not None:
+ opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale))
+ else:
+ scale = torch.empty((input.numel() // input.shape[-1], 1),
+ device=input.device,
+ dtype=torch.float32)
+ opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale))
+
+
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@@ -34,6 +45,8 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
ops_out, ref_out, atol=1,
rtol=0.0) # big atol to account for rounding errors
+ opcheck_int8_quant(ops_out, x)
+
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@@ -58,3 +71,5 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
torch.testing.assert_close(
out1, out2, atol=1,
rtol=0.0) # big atol to account for rounding errors
+
+ opcheck_int8_quant(out2, x, scale)
diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py
index 21bc38d67b..6eaf67ec75 100644
--- a/tests/kernels/test_layernorm.py
+++ b/tests/kernels/test_layernorm.py
@@ -1,6 +1,7 @@
import pytest
import torch
+from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -52,3 +53,10 @@ def test_rms_norm(
torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
else:
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
+
+ if residual is not None:
+ opcheck(torch.ops._C.fused_add_rms_norm,
+ (x, residual, layer.weight.data, layer.variance_epsilon))
+ else:
+ opcheck(torch.ops._C.rms_norm,
+ (out, x, layer.weight.data, layer.variance_epsilon))
diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py
index dadf594409..ce65aaef60 100644
--- a/tests/kernels/test_machete_gemm.py
+++ b/tests/kernels/test_machete_gemm.py
@@ -9,6 +9,7 @@ from typing import Optional, Tuple
import pytest
import torch
+from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
@@ -76,6 +77,8 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype)
+ opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
+
return w_ref, w_q_machete, w_s, w_zp
@@ -146,6 +149,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule=schedule,
)
+ opcheck(torch.ops._C.machete_gemm,
+ (a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
+ w_zp, w_s), group_size, None, None, None, schedule))
+
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py
index 18b66abe7b..721d3a6a81 100644
--- a/tests/kernels/test_marlin_gemm.py
+++ b/tests/kernels/test_marlin_gemm.py
@@ -5,6 +5,7 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
import pytest
import torch
+from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
@@ -73,12 +74,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
act_order, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
- size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
- print(f"MNK = {size_m} {size_n} {size_k}")
-
# Filter act_order
if act_order:
if group_size == -1:
@@ -112,6 +110,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
+ opcheck(torch.ops._C.gptq_marlin_repack,
+ (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
+
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
q_w_gptq,
@@ -137,12 +138,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
- size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
- print(f"MNK = {size_m} {size_n} {size_k}")
-
# Normalize group_size
if group_size == -1:
group_size = size_k
@@ -165,6 +163,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
+ opcheck(torch.ops._C.awq_marlin_repack,
+ (q_w_awq, size_k, size_n, quant_type.size_bits))
+
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq,
@@ -204,9 +205,6 @@ def test_gptq_marlin_gemm(
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
- print(f"MNK = {size_m} {size_n} {size_k}")
- print(f"groupsize = {group_size}")
-
if act_order:
if group_size == -1:
return
@@ -224,6 +222,13 @@ def test_gptq_marlin_gemm(
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
+ opcheck(
+ torch.ops._C.gptq_marlin_gemm,
+ (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
+ workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
+ a_input.shape[1], is_k_full, False, use_fp32_reduce),
+ test_utils=DEFAULT_OPCHECK_TEST_UTILS)
+
output = ops.gptq_marlin_gemm(
a_input,
marlin_q_w,
@@ -245,7 +250,6 @@ def test_gptq_marlin_gemm(
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
- print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
@@ -265,9 +269,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
- print(f"MNK = {size_m} {size_n} {size_k}")
- print(f"groupsize = {group_size}")
-
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
@@ -279,6 +280,12 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
output_ref = torch.matmul(a_input, w_24_ref)
+ opcheck(torch.ops._C.gptq_marlin_24_gemm,
+ (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
+ workspace_24.scratch, quant_type, a_input.shape[0],
+ b_weight.shape[1], a_input.shape[1]),
+ test_utils=DEFAULT_OPCHECK_TEST_UTILS)
+
output = ops.gptq_marlin_24_gemm(
a_input,
marlin_24_q_w_comp,
@@ -294,7 +301,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
- print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
@@ -321,9 +327,6 @@ def test_fp8_marlin_gemm(
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
- print(f"MNK = {size_m} {size_n} {size_k}")
- print(f"groupsize = {group_size}")
-
a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype)
@@ -353,6 +356,10 @@ def test_fp8_marlin_gemm(
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
+ opcheck(torch.ops._C.fp8_marlin_gemm,
+ (a_input, marlin_qweight, marlin_scales, workspace.scratch,
+ num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
+
output = ops.fp8_marlin_gemm(
a=a_input,
b_q_weight=marlin_qweight,
@@ -368,7 +375,6 @@ def test_fp8_marlin_gemm(
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
- print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
@@ -396,9 +402,6 @@ def test_awq_marlin_gemm(
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
- print(f"MNK = {size_m} {size_n} {size_k}")
- print(f"groupsize = {group_size}")
-
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
@@ -434,7 +437,6 @@ def test_awq_marlin_gemm(
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
- print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
@@ -460,9 +462,6 @@ def test_marlin_qqq_gemm(
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
- print(f"MNK = {size_m} {size_n} {size_k}")
- print(f"groupsize = {group_size}")
-
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
@@ -479,6 +478,11 @@ def test_marlin_qqq_gemm(
workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
MARLIN_QQQ_MAX_PARALLEL)
+ opcheck(torch.ops._C.marlin_qqq_gemm,
+ (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
+ marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
+ b_weight.shape[1], a_input.shape[1]))
+
output = ops.marlin_qqq_gemm(
q_a,
marlin_qqq_q_w,
@@ -495,6 +499,5 @@ def test_marlin_qqq_gemm(
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
- print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 3f8f650203..dbddd69c07 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -3,7 +3,8 @@
import itertools
import random
from numbers import Number
-from typing import Any, List, NamedTuple, Optional, Tuple, Union
+from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
+ Union)
import pytest
import torch
@@ -13,6 +14,21 @@ from vllm.attention.backends.xformers import XFormersBackend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad)
+# For now, disable "test_aot_dispatch_dynamic" since there are some
+# bugs related to this test in PyTorch 2.4.
+DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+ "test_schema",
+ "test_autograd_registration",
+ "test_faketensor",
+)
+
+ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+ "test_schema",
+ "test_autograd_registration",
+ "test_faketensor",
+ "test_aot_dispatch_dynamic",
+)
+
class QKVInputs(NamedTuple):
'''
@@ -926,3 +942,19 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
ideal_output = test_params.packed_qkvo.ideal_output
torch.testing.assert_close(ideal_output,
output_under_test.view_as(ideal_output))
+
+
+def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
+ torch._library.custom_ops.CustomOpDef],
+ args: Tuple[Any, ...],
+ kwargs: Optional[Dict[str, Any]] = None,
+ *,
+ test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
+ raise_exception: bool = True,
+ cond: bool = True) -> Dict[str, str]:
+ return torch.library.opcheck(
+ op,
+ args,
+ kwargs,
+ test_utils=test_utils,
+ raise_exception=raise_exception) if cond else {}
diff --git a/tests/models/test_aqlm.py b/tests/models/test_aqlm.py
index 80034a5118..de46032113 100644
--- a/tests/models/test_aqlm.py
+++ b/tests/models/test_aqlm.py
@@ -7,26 +7,6 @@ import pytest
from tests.quantization.utils import is_quant_method_supported
-# In this test we hardcode prompts and generations for the model so we don't
-# need to require the AQLM package as a dependency
-example_prompts = [
- 'vLLM is a high-throughput and memory-efficient inference and serving '
- 'engine for LLMs.\n',
- 'Briefly describe the major milestones in the development of artificial '
- 'intelligence from 1950 to 2020.\n',
- 'Compare and contrast artificial intelligence with human intelligence in '
- 'terms of processing information.\n',
- 'Describe the basic components of a neural network and how it can be '
- 'trained.\n',
- 'Write a short story about a robot that dreams for the first time.\n',
- 'Analyze the impact of the COVID-19 pandemic on global economic structures '
- 'and future business models.\n',
- 'Explain the cultural significance of the Mona Lisa painting, and how its '
- 'perception might vary in Western versus Eastern societies.\n',
- "Translate the following English sentence into Japanese, French, and "
- "Swahili: 'The early bird catches the worm.'\n"
-]
-
# These ground truth generations were generated using `transformers==4.38.1
# aqlm==1.1.0 torch==2.2.0`
# and the below code:
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 151cdbee8e..7a9061526e 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -204,6 +204,22 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_g_idx, use_exllama, bit)
+# TODO: has to be a better way to do this
+try:
+ torch.ops._C.gptq_gemm # noqa B018
+
+ @torch.library.register_fake("_C::gptq_gemm")
+ def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+ b_gptq_qzeros: torch.Tensor,
+ b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
+ use_exllama: bool, bit: int) -> torch.Tensor:
+ return torch.empty((a.size(0), b_q_weight.size(1)),
+ dtype=a.dtype,
+ device=a.device)
+except Exception:
+ pass
+
+
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
@@ -227,6 +243,194 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k)
+# TODO: has to be a better way to do this
+try:
+ torch.ops._C.gptq_marlin_24_gemm # noqa B018
+
+ @torch.library.register_fake("_C::gptq_marlin_24_gemm")
+ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+ b_meta: torch.Tensor, b_scales: torch.Tensor,
+ workspace: torch.Tensor,
+ b_q_type: ScalarType, size_m: int,
+ size_n: int, size_k: int) -> torch.Tensor:
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
+
+ @torch.library.register_fake("_C::gptq_marlin_gemm")
+ def _gptq_marlin_gemm_fake(a: torch.Tensor,
+ b_q_weight: torch.Tensor,
+ b_scales: torch.Tensor,
+ b_zeros: torch.Tensor,
+ g_idx: torch.Tensor,
+ perm: torch.Tensor,
+ workspace: torch.Tensor,
+ b_q_type: ScalarType,
+ size_m: int,
+ size_n: int,
+ size_k: int,
+ is_k_full: bool,
+ has_zp: bool = False,
+ use_fp32_reduce: bool = False) -> torch.Tensor:
+ return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
+
+ @torch.library.register_fake("_C::ggml_dequantize")
+ def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
+ n: int) -> torch.Tensor:
+ return torch.empty((m, n), dtype=torch.float16, device=W.device)
+
+ @torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
+ def _ggml_mul_mat_vec_a8_fake(
+ W: torch.Tensor,
+ X: torch.Tensor,
+ quant_type: int,
+ row: int,
+ ) -> torch.Tensor:
+ return torch.empty((1, row), dtype=torch.float16, device=W.device)
+
+ @torch.library.register_fake("_C::ggml_mul_mat_a8")
+ def _ggml_mul_mat_a8_fake(
+ W: torch.Tensor,
+ X: torch.Tensor,
+ quant_type: int,
+ row: int,
+ ) -> torch.Tensor:
+ batch = X.size(0)
+ return torch.empty((batch, row), dtype=torch.float16, device=W.device)
+
+ @torch.library.register_fake("_C::marlin_qqq_gemm")
+ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+ s_tok: torch.Tensor, s_ch: torch.Tensor,
+ s_group: torch.Tensor, workspace: torch.Tensor,
+ size_m: int, size_n: int,
+ size_k: int) -> torch.Tensor:
+ return torch.empty((size_m, size_n),
+ dtype=torch.float16,
+ device=a.device)
+
+ @torch.library.register_fake("_C::marlin_gemm")
+ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+ b_scales: torch.Tensor, workspace: torch.Tensor,
+ size_m: int, size_n: int,
+ size_k: int) -> torch.Tensor:
+ return torch.empty((size_m, size_n),
+ dtype=torch.float16,
+ device=a.device)
+
+ @torch.library.register_fake("_C::awq_dequantize")
+ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
+ zeros: torch.Tensor, split_k_iters: int, thx: int,
+ thy: int) -> torch.Tensor:
+ in_c = qweight.size(0)
+ qout_c = qweight.size(1)
+ out_c = qout_c * 8
+ return torch.empty((in_c, out_c),
+ dtype=scales.dtype,
+ device=scales.device)
+
+ @torch.library.register_fake("_C::awq_gemm")
+ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
+ qzeros: torch.Tensor, scales: torch.Tensor,
+ split_k_iters: int) -> torch.Tensor:
+ num_in_feats = input.size(0)
+ return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
+ dtype=input.dtype,
+ device=input.device).sum(0)
+
+ @torch.library.register_fake("_C::aqlm_gemm")
+ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
+ codebooks: torch.Tensor, scales: torch.Tensor,
+ codebook_partition_sizes: List[int],
+ bias: Optional[torch.Tensor]) -> torch.Tensor:
+ out_features = codes.size(0) * codebooks.size(2)
+ flat_input = input.reshape((-1, input.size(-1)))
+ flat_output = torch.empty((flat_input.size(0), out_features),
+ dtype=input.dtype,
+ device=input.device)
+
+ output_sizes = list(input.shape)
+ output_sizes.pop()
+ output_sizes.append(-1)
+ return flat_output.reshape(tuple(output_sizes))
+
+ @torch.library.register_fake("_C::aqlm_dequant")
+ def _aqlm_dequant_fake(
+ codes: torch.Tensor, codebooks: torch.Tensor,
+ codebook_partition_sizes: List[int]) -> torch.Tensor:
+ in_features = codes.size(1) * 8
+ out_features = codes.size(0)
+ return torch.empty((out_features, in_features),
+ dtype=codebooks.dtype,
+ device=codebooks.device)
+
+ @torch.library.register_fake("_C::fp8_marlin_gemm")
+ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+ b_scales: torch.Tensor, workspace: torch.Tensor,
+ num_bits: int, size_m: int, size_n: int,
+ size_k: int) -> torch.Tensor:
+ return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
+
+ @torch.library.register_fake("_C::machete_gemm")
+ def machete_gemm_fake(
+ a: torch.Tensor,
+ b_q: torch.
+ Tensor, # Should be the tensor returned by machete_prepack_B
+ b_type: ScalarType,
+ b_scales: Optional[torch.Tensor] = None,
+ b_zeros: Optional[torch.Tensor] = None,
+ b_group_size: Optional[int] = None,
+ c: Optional[torch.Tensor] = None,
+ alpha: Optional[float] = None,
+ beta: Optional[float] = None,
+ schedule: Optional[str] = None,
+ ) -> torch.Tensor:
+ m = a.size(0)
+ n = b_q.size(1)
+ return torch.empty((m, n), device=a.device, dtype=a.dtype)
+
+ @torch.library.register_fake("_C::machete_prepack_B")
+ def machete_prepack_B_fake(b_q_weight: torch.Tensor,
+ b_type: ScalarType) -> torch.Tensor:
+ return torch.empty_like(b_q_weight)
+
+ @torch.library.register_fake("_C::causal_conv1d_fwd")
+ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
+ bias_: Optional[torch.Tensor],
+ seq_idx_: Optional[torch.Tensor],
+ initial_states_: Optional[torch.Tensor],
+ final_states_out_: Optional[torch.Tensor],
+ silu_activation: bool) -> torch.Tensor:
+ return torch.empty_like(x)
+
+ @torch.library.register_fake("_C::causal_conv1d_update")
+ def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
+ weight: torch.Tensor,
+ bias_: Optional[torch.Tensor],
+ silu_activation: bool) -> torch.Tensor:
+ return torch.empty_like(x)
+
+ @torch.library.register_fake("_C::selective_scan_fwd")
+ def selective_scan_fwd_fake(
+ u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
+ B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
+ z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
+ delta_softplus: bool, index_: Optional[torch.Tensor],
+ x: Optional[torch.Tensor]) -> List[torch.Tensor]:
+ a = torch.empty_like(u)
+ if x is not None:
+ b = x
+ else:
+ b = torch.empty((u.size(0), u.size(1), A.size(1)),
+ dtype=u.dtype,
+ device=u.device)
+ if z_ is not None:
+ c = torch.empty_like(z_)
+ return [a, b, c]
+ else:
+ return [a, b]
+
+except Exception:
+ pass
+
+
# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
diff --git a/vllm/envs.py b/vllm/envs.py
index ed45047e9f..9e34b3d08b 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -203,6 +203,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")),
+ # Internal flag to enable Dynamo fullgraph capture
+ "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
+ lambda: bool(
+ os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
+
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 29dd09afac..9b7cc22869 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -733,7 +733,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
- destination_indices = set([range(batch_size)])
+ destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index cf8bc3e6a1..0cfca0671b 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -75,6 +75,10 @@ _NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
+# For now, bump up cache limits for recompilations during CUDA graph warmups.
+torch._dynamo.config.cache_size_limit = 128
+torch._dynamo.config.accumulated_cache_size_limit = 128
+
@dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
@@ -1060,9 +1064,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
- self.model = torch.compile(self.model,
- fullgraph=True,
- backend="eager")
+ self.model = torch.compile(
+ self.model,
+ fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+ backend="eager")
def save_sharded_state(
self,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 0ff559a9af..52092dc2dc 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -166,6 +166,7 @@ class Worker(LocalOrDistributedWorkerBase):
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
+ gc.collect()
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else: