diff --git a/csrc/ops.h b/csrc/ops.h
index 567d9fae4b..4952e826ec 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
 
 #endif
 
-void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
-                              float scale);
+void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                              torch::Tensor const& scale);
 
 void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                      torch::Tensor lookup_table);
diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index 4902e4c234..11baa5d414 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -28,9 +28,10 @@ namespace vllm {
 template <typename scalar_t, typename scale_type>
 __global__ void static_scaled_int8_quant_kernel(
     const scalar_t* __restrict__ input, int8_t* __restrict__ out,
-    scale_type scale, const int hidden_size) {
+    const scale_type* scale_ptr, const int hidden_size) {
   const int tid = threadIdx.x;
   const int token_idx = blockIdx.x;
+  scale_type scale = *scale_ptr;
 
   for (int i = tid; i < hidden_size; i += blockDim.x) {
     out[token_idx * hidden_size + i] =
@@ -39,11 +40,13 @@ __global__ void static_scaled_int8_quant_kernel(
 }
 }  // namespace vllm
 
-void static_scaled_int8_quant(torch::Tensor& out,    // [..., hidden_size]
-                              torch::Tensor& input,  // [..., hidden_size]
-                              float scale) {
+void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
+                              torch::Tensor const& input,  // [..., hidden_size]
+                              torch::Tensor const& scale) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(scale.numel() == 1);
+
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
   dim3 grid(num_tokens);
@@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                        out.data_ptr<int8_t>(), scale,
-                                        hidden_size);
+                                        out.data_ptr<int8_t>(),
+                                        scale.data_ptr<float>(), hidden_size);
       });
 }
diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py
index b9aa00ce13..29890118c9 100644
--- a/tests/kernels/test_int8_quant.py
+++ b/tests/kernels/test_int8_quant.py
@@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
         torch.iinfo(torch.int8).min,
         torch.iinfo(torch.int8).max).to(torch.int8)
     out2 = torch.empty_like(x, dtype=torch.int8)
-    ops.static_scaled_int8_quant(out2, x, scale)
+    scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
+
+    ops.static_scaled_int8_quant(out2, x, scale_argument)
     assert torch.allclose(out1, out2,
                           atol=1)  # big atol to account for rounding errors
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 22cf5a44e3..8a6f6d96d8 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -265,7 +265,7 @@ def scaled_fp8_quant(
 
 # int8
 def static_scaled_int8_quant(input: torch.Tensor,
-                             scale: float) -> torch.Tensor:
+                             scale: torch.Tensor) -> torch.Tensor:
     """
     Quantize the input tensor to int8 and return the quantized tensor.
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
index 7e3e932cfe..2dfc6e2b07 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
@@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
         act_scale = layer.input_scale
 
         # Input quantize
-        x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item())
+        x_q = custom_ops.static_scaled_int8_quant(x, act_scale)
 
         return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
                                                weight_scale, x.dtype)
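For reference, a minimal usage sketch of the updated Python wrapper (assuming a CUDA build of vLLM with this change applied; the tensor shapes and scale value below are illustrative). The static activation scale is now passed as a one-element float32 tensor resident on the GPU rather than a Python float, so callers such as CompressedTensorsW8A8StaticTensor no longer need the act_scale[0].item() host sync.

import torch

from vllm import _custom_ops as custom_ops

# Illustrative shapes: 16 tokens, hidden size 1024.
x = torch.rand(16, 1024, dtype=torch.float16, device="cuda")

# The scale is a single-element float32 tensor on the device; the CUDA kernel
# now dereferences this pointer instead of receiving the value by copy.
act_scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")

# Returns the int8-quantized tensor, per the wrapper in vllm/_custom_ops.py.
x_q = custom_ops.static_scaled_int8_quant(x, act_scale)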