[Kernel] Pass a device pointer into the quantize kernel for the scales (#5159)

This commit is contained in:
Tyler Michael Smith 2024-06-03 12:52:30 -04:00 committed by GitHub
parent 0ab278ca31
commit cbb2f59cc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 16 additions and 11 deletions

View File

@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
float scale);
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table);

View File

@ -28,9 +28,10 @@ namespace vllm {
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
scale_type scale, const int hidden_size) {
const scale_type* scale_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;
scale_type scale = *scale_ptr;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
@ -40,10 +41,12 @@ __global__ void static_scaled_int8_quant_kernel(
} // namespace vllm
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
float scale) {
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& scale) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
vllm::static_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(), scale,
hidden_size);
out.data_ptr<int8_t>(),
scale.data_ptr<float>(), hidden_size);
});
}

View File

@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
torch.iinfo(torch.int8).min,
torch.iinfo(torch.int8).max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
ops.static_scaled_int8_quant(out2, x, scale)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
ops.static_scaled_int8_quant(out2, x, scale_argument)
assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors

View File

@ -265,7 +265,7 @@ def scaled_fp8_quant(
# int8
def static_scaled_int8_quant(input: torch.Tensor,
scale: float) -> torch.Tensor:
scale: torch.Tensor) -> torch.Tensor:
"""
Quantize the input tensor to int8 and return the quantized tensor.

View File

@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
act_scale = layer.input_scale
# Input quantize
x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item())
x_q = custom_ops.static_scaled_int8_quant(x, act_scale)
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
weight_scale, x.dtype)