Fix overflow in awq kernel (#1295)

Co-authored-by: 楚天翔 <tianxiang.ctx@alibaba-inc.com>
This commit is contained in:
CHU Tianxiang 2023-10-11 15:19:53 +08:00 committed by GitHub
parent 8285736840
commit 980dd4a2c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -90,7 +90,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
+ (((int)threadIdx.x) % (128 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2;
@ -323,7 +323,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;