mirror of https://github.com/vllm-project/vllm
Fix integer overflow in AWQ kernel (#1295)
Co-authored-by: 楚天翔 <tianxiang.ctx@alibaba-inc.com>
This commit is contained in:
parent
8285736840
commit
980dd4a2c4
|
@@ -90,7 +90,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
|||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
- + blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||
+ + static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 128
|
||||
+ ((int)threadIdx.y) * 64
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
@@ -323,7 +323,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
|
|||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
- + blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||
+ + static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 64
|
||||
+ ((int)threadIdx.y) * 32
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
|
Loading…
Reference in New Issue