mirror of https://github.com/vllm-project/vllm
AWQ: Up to 2.66x higher throughput (#2566)
This commit is contained in:
parent
390b495ff3
commit
beb89f68b4
|
@ -70,6 +70,14 @@ torch::Tensor awq_gemm(
|
|||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy);
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
|
|
|
@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
#ifndef USE_ROCM
|
||||
// Quantization ops
|
||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||
#endif
|
||||
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||
|
|
|
@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
|
|||
#endif
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(64) dequantize_weights(
|
||||
int* __restrict__ B,
|
||||
half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros,
|
||||
half* __restrict__ C,
|
||||
int G
|
||||
)
|
||||
{
|
||||
int j_factors1 = 4;
|
||||
int row_stride2 = 4;
|
||||
int split_k_iters = 1;
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
half B_shared[32 * (128 + 8)];
|
||||
|
||||
half* B_shared_ptr2 = B_shared;
|
||||
|
||||
half B_shared_warp[32];
|
||||
int OC = 512;
|
||||
|
||||
int N = blockDim.x * gridDim.x; // 2
|
||||
int col = (blockIdx.x * blockDim.x + threadIdx.x);
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int index1 = 8 * col + 8 * row * N;
|
||||
half* C_ptr2 = C + index1;
|
||||
|
||||
int index2 = col + row * N;
|
||||
int* B_ptr2 = B + index2;
|
||||
|
||||
int index3 = col + (int)(row / G) * N;
|
||||
int* zeros_ptr2 = zeros + index3;
|
||||
int index4 = 8 * col + (int)(row / G) * N * 8;
|
||||
half* scaling_factors_ptr2 = scaling_factors + index4;
|
||||
|
||||
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
|
||||
int j=0;
|
||||
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j);
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
|
||||
*(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16;
|
||||
|
||||
for (int i=0; i<8; ++i) {
|
||||
*(C_ptr2 + i) = B_shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy)
|
||||
{
|
||||
int in_c = _kernel.size(0);
|
||||
int qout_c = _kernel.size(1);
|
||||
int out_c = qout_c * 8;
|
||||
int G = in_c / _scaling_factors.size(0);
|
||||
|
||||
int x_thread = thx;
|
||||
int y_thread = thy;
|
||||
|
||||
int x_blocks = 1;
|
||||
int y_blocks = 1;
|
||||
if (thx==0) {
|
||||
x_thread = qout_c;
|
||||
}
|
||||
if (thy==0) {
|
||||
y_thread = in_c;
|
||||
}
|
||||
if (thx==0 && thy==0) {
|
||||
x_thread = 8;
|
||||
y_thread = 8;
|
||||
x_blocks = (int)(qout_c / 8);
|
||||
y_blocks = (int)(in_c / 8);
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
|
||||
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
|
||||
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
|
||||
dim3 num_blocks(x_blocks, y_blocks);
|
||||
dim3 threads_per_block(x_thread, y_thread);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
kernel, scaling_factors, zeros, de_kernel, G);
|
||||
|
||||
return _de_kernel;
|
||||
}
|
||||
|
||||
// in_feats: M, IC [float16]
|
||||
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||
// scaling_factors: IC // G, OC [float16]
|
||||
|
|
|
@ -153,7 +153,16 @@ class AWQLinearMethod(LinearMethodBase):
|
|||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
|
||||
|
||||
# num_tokens >= threshold
|
||||
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
|
||||
|
||||
if FP16_MATMUL_HEURISTIC_CONDITION:
|
||||
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
||||
out = torch.matmul(reshaped_x, out)
|
||||
else:
|
||||
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
|
||||
pack_factor)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.reshape(out_shape)
|
||||
|
|
Loading…
Reference in New Issue