[Kernel][Misc] register ops to prevent graph breaks (#6917)

Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
bnellnm 2024-09-11 15:52:19 -04:00 committed by GitHub
parent 7015417fd4
commit 73202dbe77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 528 additions and 102 deletions

View File

@ -39,6 +39,16 @@ FIX #xxxx (*link existing issues this PR will resolve*)
<li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.</li> <li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.</li>
</ul> </ul>
<h3>Adding or changing kernels</h3>
<p>Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.</p>
<ul>
<li>Make sure custom ops are registered following PyTorch guidelines: <a href="https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial">Custom C++ and CUDA Operators</a> and <a href="https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU">The Custom Operators Manual</a></li>
<li>Custom operations that return <code>Tensors</code> require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.</li>
<li>Use <a href="https://pytorch.org/docs/stable/library.html#torch.library.opcheck"><code>torch.libary.opcheck()</code></a> to test the function registration and meta-function for any registered ops. See <code>tests/kernels</code> for examples.</li>
<li>When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.</li>
<li>If a new custom type is needed, see the following document: <a href="https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA">Custom Class Support in PT2</a>.
</ul>
<h3>Notes for Large Changes</h3> <h3>Notes for Large Changes</h3>
<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p> <p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>

View File

@ -350,6 +350,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
${GPU_INCLUDE_DIRECTORIES}) ${GPU_INCLUDE_DIRECTORIES})
# TODO: is torch_python_LIBRARY needed?
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY} target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
${GPU_LIBRARIES}) ${GPU_LIBRARIES})

View File

@ -32,8 +32,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// PagedAttention V2. // PagedAttention V2.
ops.def( ops.def(
"paged_attention_v2(" "paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits," " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache," " Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
@ -122,8 +122,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Copy the cache blocks from src to dst. // Copy the cache blocks from src to dst.
cache_ops.def( cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"block_mapping) -> ()"); "Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks); cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
// Reshape the key and value tensors and cache them. // Reshape the key and value tensors and cache them.

View File

@ -123,9 +123,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n, int64_t size_k, int64_t size_n,
int64_t num_bits); int64_t num_bits);
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits); int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits);
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n); int64_t n);

View File

@ -267,3 +267,15 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
} }
#endif #endif
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}

View File

@ -342,3 +342,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
} }
#endif #endif
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}

View File

@ -36,8 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// PagedAttention V2. // PagedAttention V2.
ops.def( ops.def(
"paged_attention_v2(" "paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits," " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache," " Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
@ -73,7 +73,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
// prepare_inputs advance_step // prepare_inputs advance_step
ops.def("advance_step", &advance_step); ops.def(
"advance_step(int num_seqs, int num_queries, int block_size, "
"Tensor! input_tokens, Tensor sampled_token_ids, "
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
"Tensor block_tables) -> ()");
ops.impl("advance_step", torch::kCUDA, &advance_step); ops.impl("advance_step", torch::kCUDA, &advance_step);
// Layernorm // Layernorm
@ -110,27 +114,56 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization ops // Quantization ops
#ifndef USE_ROCM #ifndef USE_ROCM
// Quantized GEMM for AQLM. // Quantized GEMM for AQLM.
ops.def("aqlm_gemm", &aqlm_gemm); ops.def(
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
"Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
"-> Tensor");
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
// Decompression method for AQLM. // Decompression method for AQLM.
ops.def("aqlm_dequant", &aqlm_dequant); ops.def(
"aqlm_dequant(Tensor codes, Tensor codebooks, "
"int[] codebook_partition_sizes) -> Tensor");
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
// Quantized GEMM for AWQ. // Quantized GEMM for AWQ.
ops.def("awq_gemm", &awq_gemm); ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros, int split_k_iters) -> Tensor");
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
// Dequantization for AWQ. // Dequantization for AWQ.
ops.def("awq_dequantize", &awq_dequantize); ops.def(
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
// Note about marlin kernel 'workspace' arguments:
// Technically these should be mutable since they are modified by the kernel.
// But since they are set back to zero once the kernel is finished we can
// hand wave and say that they have no net effect.
//
// The reason to mark 'workspace' as immutable is so that they don't interfere
// with using ScalarType arguments in the ops. If they are marked as mutable,
// pytorch throws an assert in
// 'torch._higher_order_ops._register_effectful_op' that prevents these
// kernels from being torch.compile'd.
// See the following document for more info on custom types and ops that use
// custom types:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
// Marlin (Dense) Optimized Quantized GEMM for GPTQ. // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
ops.def("marlin_gemm", &marlin_gemm); ops.def(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); ops.def(
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
"Tensor b_scales, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k) -> Tensor");
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper. // Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
@ -149,35 +182,55 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
// gptq_marlin Optimized Quantized GEMM for GPTQ. // gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); ops.def(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
// gptq_marlin repack from GPTQ. // gptq_marlin repack from GPTQ.
ops.def("gptq_marlin_repack", &gptq_marlin_repack); ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
// awq_marlin repack from AWQ. // awq_marlin repack from AWQ.
ops.def("awq_marlin_repack", &awq_marlin_repack); ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor");
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
// Dequantization for GGML. // Dequantization for GGML.
ops.def("ggml_dequantize", &ggml_dequantize); ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize); ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
// mmvq kernel for GGML. // mmvq kernel for GGML.
ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8); ops.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
"-> Tensor");
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8); ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
// mmq kernel for GGML. // mmq kernel for GGML.
ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8); ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only. // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm); ops.def(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, int size_m, int size_n, "
"int size_k) -> Tensor");
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
// marlin_qqq_gemm for QQQ. // marlin_qqq_gemm for QQQ.
ops.def("marlin_qqq_gemm", &marlin_qqq_gemm); ops.def(
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, int size_m, int size_n, "
"int size_k) -> Tensor");
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm); ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@ -199,16 +252,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Check if cutlass scaled_mm is supported for CUDA devices of the given // Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability // capability
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA, ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
&cutlass_scaled_mm_supports_fp8);
// Mamba selective scan kernel // Mamba selective scan kernel
ops.def( ops.def(
"selective_scan_fwd(Tensor! u, Tensor! delta," "selective_scan_fwd(Tensor! u, Tensor! delta,"
"Tensor! A, Tensor! B, Tensor! C," "Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor? z_, Tensor? delta_bias_," "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
"bool delta_softplus," "bool delta_softplus,"
"Tensor? index_, Tensor? x) -> Tensor[]"); "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
ops.def( ops.def(
@ -230,7 +283,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif #endif
// Quantized GEMM for GPTQ. // Quantized GEMM for GPTQ.
ops.def("gptq_gemm", &gptq_gemm); // Note: even though the C++ inferred schema is correct for this op, it seems
// to prevent the meta function registry.
ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
"-> Tensor");
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
// Post processing for GPTQ. // Post processing for GPTQ.
@ -250,8 +308,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute dynamic-per-token FP8 quantized tensor and scaling factor. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def( ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! " "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
"scale, Tensor? scale_ub) -> " "Tensor! scale, Tensor? scale_ub) -> "
"()"); "()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant); &dynamic_per_token_scaled_fp8_quant);
@ -288,8 +346,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Copy the cache blocks from src to dst. // Copy the cache blocks from src to dst.
cache_ops.def( cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"block_mapping) -> ()"); "Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks); cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
// Reshape the key and value tensors and cache them. // Reshape the key and value tensors and cache them.
@ -314,8 +372,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Convert the key and value cache to fp8 data type. // Convert the key and value cache to fp8 data type.
cache_ops.def( cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str " "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
"kv_cache_dtype) -> ()"); "str kv_cache_dtype) -> ()");
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
} }
@ -323,24 +381,28 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils // Cuda utils
// Gets the specified device attribute. // Gets the specified device attribute.
cuda_utils.def("get_device_attribute", &get_device_attribute); cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute); cuda_utils.impl("get_device_attribute", &get_device_attribute);
// Gets the maximum shared memory per block device attribute. // Gets the maximum shared memory per block device attribute.
cuda_utils.def("get_max_shared_memory_per_block_device_attribute", cuda_utils.def(
&get_max_shared_memory_per_block_device_attribute); "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
torch::kCUDA,
&get_max_shared_memory_per_block_device_attribute); &get_max_shared_memory_per_block_device_attribute);
} }
#ifndef USE_ROCM #ifndef USE_ROCM
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels // Custom all-reduce kernels
custom_ar.def("init_custom_ar", &init_custom_ar); custom_ar.def(
"init_custom_ar(Tensor meta, Tensor rank_data, "
"str[] handles, int[] offsets, int rank, "
"bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def("should_custom_ar", &should_custom_ar); custom_ar.def(
"should_custom_ar(Tensor inp, int max_size, int world_size, "
"bool full_nvlink) -> bool");
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
@ -352,21 +414,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
custom_ar.def("dispose", &dispose); custom_ar.def("dispose", &dispose);
custom_ar.impl("dispose", torch::kCPU, &dispose);
custom_ar.def("meta_size", &meta_size); custom_ar.def("meta_size", &meta_size);
custom_ar.impl("meta_size", torch::kCPU, &meta_size);
custom_ar.def("register_buffer", &register_buffer); custom_ar.def(
"register_buffer(int fa, Tensor t, str[] handles, "
"int[] offsets) -> ()");
custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer); custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
&get_graph_buffer_ipc_meta);
custom_ar.def("register_graph_buffers", &register_graph_buffers); custom_ar.def("register_graph_buffers", &register_graph_buffers);
custom_ar.impl("register_graph_buffers", torch::kCPU,
&register_graph_buffers);
} }
#endif #endif

View File

@ -3,8 +3,10 @@ from typing import Type
import pytest import pytest
import torch import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
NewGELU, SiluAndMul) NewGELU, QuickGELU,
SiluAndMul)
from .allclose_default import get_default_atol, get_default_rtol from .allclose_default import get_default_atol, get_default_rtol
@ -39,18 +41,28 @@ def test_act_and_mul(
x = torch.randn(num_tokens, 2 * d, dtype=dtype) x = torch.randn(num_tokens, 2 * d, dtype=dtype)
if activation == "silu": if activation == "silu":
layer = SiluAndMul() layer = SiluAndMul()
fn = torch.ops._C.silu_and_mul
elif activation == "gelu": elif activation == "gelu":
layer = GeluAndMul(approximate="none") layer = GeluAndMul(approximate="none")
fn = torch.ops._C.gelu_and_mul
elif activation == "gelu_tanh": elif activation == "gelu_tanh":
layer = GeluAndMul(approximate="tanh") layer = GeluAndMul(approximate="tanh")
fn = torch.ops._C.gelu_tanh_and_mul
out = layer(x) out = layer(x)
ref_out = layer.forward_native(x) ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch # The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison. # implementations, so we can do exact comparison.
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
opcheck(fn, (out, x))
@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
(NewGELU, torch.ops._C.gelu_new),
(QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D) @pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@ -70,10 +82,14 @@ def test_activation(
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype) x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation() layer = activation[0]()
fn = activation[1]
out = layer(x) out = layer(x)
ref_out = layer.forward_native(x) ref_out = layer.forward_native(x)
torch.testing.assert_close(out, torch.testing.assert_close(out,
ref_out, ref_out,
atol=get_default_atol(out), atol=get_default_atol(out),
rtol=get_default_rtol(out)) rtol=get_default_rtol(out))
out = torch.empty_like(x)
opcheck(fn, (out, x))

View File

@ -6,6 +6,7 @@ import torch
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip from vllm.utils import get_max_shared_memory_bytes, is_hip
@ -198,6 +199,13 @@ def test_paged_attention(
k_scale, k_scale,
v_scale, v_scale,
) )
opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
elif version == "v2": elif version == "v2":
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0 assert PARTITION_SIZE % block_size == 0
@ -230,6 +238,14 @@ def test_paged_attention(
k_scale, k_scale,
v_scale, v_scale,
) )
opcheck(torch.ops._C.paged_attention_v2,
(output, exp_sums, max_logits, tmp_output, query, key_cache,
value_cache, num_kv_heads, scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]))
else: else:
raise AssertionError(f"Unknown version: {version}") raise AssertionError(f"Unknown version: {version}")

View File

@ -4,6 +4,7 @@ from typing import List, Tuple
import pytest import pytest
import torch import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@ -87,6 +88,11 @@ def test_copy_blocks(
block_mapping_tensor = torch.tensor(block_mapping, block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64, dtype=torch.int64,
device=device).view(-1, 2) device=device).view(-1, 2)
opcheck(torch.ops._C_cache_ops.copy_blocks,
(key_caches, value_caches, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
cond=(head_size == HEAD_SIZES[0]))
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
# Run the reference implementation. # Run the reference implementation.
@ -162,6 +168,10 @@ def test_reshape_and_cache(
k_scale = v_scale = 1.0 k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale) kv_cache_dtype, k_scale, v_scale)
@ -269,6 +279,10 @@ def test_reshape_and_cache_flash(
k_scale = v_scale = 1.0 k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache, ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale) slot_mapping, kv_cache_dtype, k_scale, v_scale)
@ -366,6 +380,14 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone() src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel. # Call the swap_blocks kernel.
do_opcheck = (head_size == HEAD_SIZES[0])
opcheck(torch.ops._C_cache_ops.swap_blocks,
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
cond=do_opcheck)
opcheck(torch.ops._C_cache_ops.swap_blocks,
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
cond=do_opcheck)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
block_mapping_tensor) block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], ops.swap_blocks(src_value_caches[0], dist_value_caches[0],

View File

@ -7,6 +7,7 @@ from typing import Optional, Type
import pytest import pytest
import torch import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -108,6 +109,9 @@ def cutlass_int8_gemm_helper(m: int,
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33]) @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024]) @pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
@ -341,6 +345,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol) torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol) torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
if azp_per_token:
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
func_bias))
else:
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
func_bias))
# Test working with a subset of A and B # Test working with a subset of A and B
def test_cutlass_subset(): def test_cutlass_subset():

View File

@ -2,6 +2,7 @@ import pytest
import torch import torch
from tests.kernels.quant_utils import ref_dynamic_per_token_quant from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from tests.kernels.utils import opcheck
from vllm._custom_ops import scaled_int8_quant from vllm._custom_ops import scaled_int8_quant
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -12,6 +13,16 @@ SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1] SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
def opcheck_int8_quant(output, input, scale=None):
if scale is not None:
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale))
else:
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@ -34,6 +45,8 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
ops_out, ref_out, atol=1, ops_out, ref_out, atol=1,
rtol=0.0) # big atol to account for rounding errors rtol=0.0) # big atol to account for rounding errors
opcheck_int8_quant(ops_out, x)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@ -58,3 +71,5 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
torch.testing.assert_close( torch.testing.assert_close(
out1, out2, atol=1, out1, out2, atol=1,
rtol=0.0) # big atol to account for rounding errors rtol=0.0) # big atol to account for rounding errors
opcheck_int8_quant(out2, x, scale)

View File

@ -1,6 +1,7 @@
import pytest import pytest
import torch import torch
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
DTYPES = [torch.half, torch.bfloat16, torch.float] DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -52,3 +53,10 @@ def test_rms_norm(
torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2) torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
else: else:
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
if residual is not None:
opcheck(torch.ops._C.fused_add_rms_norm,
(x, residual, layer.weight.data, layer.variance_epsilon))
else:
opcheck(torch.ops._C.rms_norm,
(out, x, layer.weight.data, layer.variance_epsilon))

View File

@ -9,6 +9,7 @@ from typing import Optional, Tuple
import pytest import pytest
import torch import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights) pack_rows, quantize_weights)
@ -76,6 +77,8 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q = w_q.t().contiguous().t() # convert to col major w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype) w_q_machete = ops.machete_prepack_B(w_q, wtype)
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
return w_ref, w_q_machete, w_s, w_zp return w_ref, w_q_machete, w_s, w_zp
@ -146,6 +149,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule=schedule, schedule=schedule,
) )
opcheck(torch.ops._C.machete_gemm,
(a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
w_zp, w_s), group_size, None, None, None, schedule))
# Relax atol as our reduction dim becomes larger (more rounding error) # Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies # Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0 # zeropoints (after scales) causes noise around 0

View File

@ -5,6 +5,7 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
import pytest import pytest
import torch import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
@ -73,12 +74,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
act_order, mnk_factors): act_order, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
# Filter act_order # Filter act_order
if act_order: if act_order:
if group_size == -1: if group_size == -1:
@ -112,6 +110,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm) weight_perm)
opcheck(torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
# Run Marlin repack GPU kernel # Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack( marlin_q_w_2 = ops.gptq_marlin_repack(
q_w_gptq, q_w_gptq,
@ -137,12 +138,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors): mnk_factors):
m_factor, n_factor, k_factor = mnk_factors m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
# Normalize group_size # Normalize group_size
if group_size == -1: if group_size == -1:
group_size = size_k group_size = size_k
@ -165,6 +163,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm) weight_perm)
opcheck(torch.ops._C.awq_marlin_repack,
(q_w_awq, size_k, size_n, quant_type.size_bits))
# Run Marlin repack GPU kernel # Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack( marlin_q_w_2 = ops.awq_marlin_repack(
q_w_awq, q_w_awq,
@ -204,9 +205,6 @@ def test_gptq_marlin_gemm(
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
if act_order: if act_order:
if group_size == -1: if group_size == -1:
return return
@ -224,6 +222,13 @@ def test_gptq_marlin_gemm(
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL) GPTQ_MARLIN_MAX_PARALLEL)
opcheck(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_gemm( output = ops.gptq_marlin_gemm(
a_input, a_input,
marlin_q_w, marlin_q_w,
@ -245,7 +250,6 @@ def test_gptq_marlin_gemm(
torch.cuda.synchronize() torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04 assert max_diff < 0.04
@ -265,9 +269,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k)) a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
@ -279,6 +280,12 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
output_ref = torch.matmul(a_input, w_24_ref) output_ref = torch.matmul(a_input, w_24_ref)
opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_24_gemm( output = ops.gptq_marlin_24_gemm(
a_input, a_input,
marlin_24_q_w_comp, marlin_24_q_w_comp,
@ -294,7 +301,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
torch.cuda.synchronize() torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04 assert max_diff < 0.04
@ -321,9 +327,6 @@ def test_fp8_marlin_gemm(
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k), dtype=dtype) a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype) b_weight = rand_data((size_k, size_n), dtype=dtype)
@ -353,6 +356,10 @@ def test_fp8_marlin_gemm(
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL) GPTQ_MARLIN_MAX_PARALLEL)
opcheck(torch.ops._C.fp8_marlin_gemm,
(a_input, marlin_qweight, marlin_scales, workspace.scratch,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
output = ops.fp8_marlin_gemm( output = ops.fp8_marlin_gemm(
a=a_input, a=a_input,
b_q_weight=marlin_qweight, b_q_weight=marlin_qweight,
@ -368,7 +375,6 @@ def test_fp8_marlin_gemm(
torch.cuda.synchronize() torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04 assert max_diff < 0.04
@ -396,9 +402,6 @@ def test_awq_marlin_gemm(
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k)) a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
@ -434,7 +437,6 @@ def test_awq_marlin_gemm(
torch.cuda.synchronize() torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04 assert max_diff < 0.04
@ -460,9 +462,6 @@ def test_marlin_qqq_gemm(
size_k = k_chunk * k_factor size_k = k_chunk * k_factor
size_n = n_chunk * n_factor size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k)) a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n)) b_weight = rand_data((size_k, size_n))
@ -479,6 +478,11 @@ def test_marlin_qqq_gemm(
workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N, workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
MARLIN_QQQ_MAX_PARALLEL) MARLIN_QQQ_MAX_PARALLEL)
opcheck(torch.ops._C.marlin_qqq_gemm,
(q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]))
output = ops.marlin_qqq_gemm( output = ops.marlin_qqq_gemm(
q_a, q_a,
marlin_qqq_q_w, marlin_qqq_q_w,
@ -495,6 +499,5 @@ def test_marlin_qqq_gemm(
torch.cuda.synchronize() torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref) max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04 assert max_diff < 0.04

View File

@ -3,7 +3,8 @@
import itertools import itertools
import random import random
from numbers import Number from numbers import Number
from typing import Any, List, NamedTuple, Optional, Tuple, Union from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
Union)
import pytest import pytest
import torch import torch
@ -13,6 +14,21 @@ from vllm.attention.backends.xformers import XFormersBackend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad) make_tensor_with_pad)
# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
"test_schema",
"test_autograd_registration",
"test_faketensor",
)
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
"test_schema",
"test_autograd_registration",
"test_faketensor",
"test_aot_dispatch_dynamic",
)
class QKVInputs(NamedTuple): class QKVInputs(NamedTuple):
''' '''
@ -926,3 +942,19 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
ideal_output = test_params.packed_qkvo.ideal_output ideal_output = test_params.packed_qkvo.ideal_output
torch.testing.assert_close(ideal_output, torch.testing.assert_close(ideal_output,
output_under_test.view_as(ideal_output)) output_under_test.view_as(ideal_output))
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
torch._library.custom_ops.CustomOpDef],
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
*,
test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
raise_exception: bool = True,
cond: bool = True) -> Dict[str, str]:
return torch.library.opcheck(
op,
args,
kwargs,
test_utils=test_utils,
raise_exception=raise_exception) if cond else {}

View File

@ -7,26 +7,6 @@ import pytest
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
# In this test we hardcode prompts and generations for the model so we don't
# need to require the AQLM package as a dependency
example_prompts = [
'vLLM is a high-throughput and memory-efficient inference and serving '
'engine for LLMs.\n',
'Briefly describe the major milestones in the development of artificial '
'intelligence from 1950 to 2020.\n',
'Compare and contrast artificial intelligence with human intelligence in '
'terms of processing information.\n',
'Describe the basic components of a neural network and how it can be '
'trained.\n',
'Write a short story about a robot that dreams for the first time.\n',
'Analyze the impact of the COVID-19 pandemic on global economic structures '
'and future business models.\n',
'Explain the cultural significance of the Mona Lisa painting, and how its '
'perception might vary in Western versus Eastern societies.\n',
"Translate the following English sentence into Japanese, French, and "
"Swahili: 'The early bird catches the worm.'\n"
]
# These ground truth generations were generated using `transformers==4.38.1 # These ground truth generations were generated using `transformers==4.38.1
# aqlm==1.1.0 torch==2.2.0` # aqlm==1.1.0 torch==2.2.0`
# and the below code: # and the below code:

View File

@ -204,6 +204,22 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_g_idx, use_exllama, bit) b_g_idx, use_exllama, bit)
# TODO: has to be a better way to do this
try:
torch.ops._C.gptq_gemm # noqa B018
@torch.library.register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
use_exllama: bool, bit: int) -> torch.Tensor:
return torch.empty((a.size(0), b_q_weight.size(1)),
dtype=a.dtype,
device=a.device)
except Exception:
pass
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None: bit: int) -> None:
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
@ -227,6 +243,194 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n, size_k) size_n, size_k)
# TODO: has to be a better way to do this
try:
torch.ops._C.gptq_marlin_24_gemm # noqa B018
@torch.library.register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
b_zeros: torch.Tensor,
g_idx: torch.Tensor,
perm: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
) -> torch.Tensor:
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
@torch.library.register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@torch.library.register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@torch.library.register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
in_c = qweight.size(0)
qout_c = qweight.size(1)
out_c = qout_c * 8
return torch.empty((in_c, out_c),
dtype=scales.dtype,
device=scales.device)
@torch.library.register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor:
num_in_feats = input.size(0)
return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
dtype=input.dtype,
device=input.device).sum(0)
@torch.library.register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
bias: Optional[torch.Tensor]) -> torch.Tensor:
out_features = codes.size(0) * codebooks.size(2)
flat_input = input.reshape((-1, input.size(-1)))
flat_output = torch.empty((flat_input.size(0), out_features),
dtype=input.dtype,
device=input.device)
output_sizes = list(input.shape)
output_sizes.pop()
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))
@torch.library.register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
in_features = codes.size(1) * 8
out_features = codes.size(0)
return torch.empty((out_features, in_features),
dtype=codebooks.dtype,
device=codebooks.device)
@torch.library.register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@torch.library.register_fake("_C::machete_gemm")
def machete_gemm_fake(
a: torch.Tensor,
b_q: torch.
Tensor, # Should be the tensor returned by machete_prepack_B
b_type: ScalarType,
b_scales: Optional[torch.Tensor] = None,
b_zeros: Optional[torch.Tensor] = None,
b_group_size: Optional[int] = None,
c: Optional[torch.Tensor] = None,
alpha: Optional[float] = None,
beta: Optional[float] = None,
schedule: Optional[str] = None,
) -> torch.Tensor:
m = a.size(0)
n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype)
@torch.library.register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight)
@torch.library.register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
seq_idx_: Optional[torch.Tensor],
initial_states_: Optional[torch.Tensor],
final_states_out_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)
@torch.library.register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(
u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, index_: Optional[torch.Tensor],
x: Optional[torch.Tensor]) -> List[torch.Tensor]:
a = torch.empty_like(u)
if x is not None:
b = x
else:
b = torch.empty((u.size(0), u.size(1), A.size(1)),
dtype=u.dtype,
device=u.device)
if z_ is not None:
c = torch.empty_like(z_)
return [a, b, c]
else:
return [a, b]
except Exception:
pass
# cutlass # cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)

View File

@ -203,6 +203,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")), ("true", "1")),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool(
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
"LOCAL_RANK": "LOCAL_RANK":

View File

@ -733,7 +733,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
indices_for_current_run: List[int]): indices_for_current_run: List[int]):
# move out all of the occupied but currently not running blocks # move out all of the occupied but currently not running blocks
# outside of the first n blocks # outside of the first n blocks
destination_indices = set([range(batch_size)]) destination_indices = range(batch_size)
max_possible_batch_size = self.mamba_cache[0].shape[1] max_possible_batch_size = self.mamba_cache[0].shape[1]
for destination_index in destination_indices: for destination_index in destination_indices:
if destination_index in self._get_all_occupied_indices() and \ if destination_index in self._get_all_occupied_indices() and \

View File

@ -75,6 +75,10 @@ _NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
# For now, bump up cache limits for recompilations during CUDA graph warmups.
torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.accumulated_cache_size_limit = 128
@dataclass(frozen=True) @dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase): class ModelInputForGPU(ModelRunnerInputBase):
@ -1060,8 +1064,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"This may lead to less accurate results!") "This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
self.model = torch.compile(self.model, self.model = torch.compile(
fullgraph=True, self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend="eager") backend="eager")
def save_sharded_state( def save_sharded_state(

View File

@ -166,6 +166,7 @@ class Worker(LocalOrDistributedWorkerBase):
torch.cuda.set_device(self.device) torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype) _check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0] self.init_gpu_memory = torch.cuda.mem_get_info()[0]
else: else: