Faster kernels for quantized matmul on cuda (#2060)

* Hook the quantized matmul cuda kernels.

* Add a (currently broken) test.

* Kernel fixes.

* Fix by transposing the rhs matrix.

* Add the q4-1 kernels.

* Proper block sizes.

* More details in the tests.
This commit is contained in:
Laurent Mazare 2024-04-15 08:32:47 +02:00 committed by GitHub
parent c119600d6e
commit f7d5bf5b97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 255 additions and 17 deletions

View File

@ -40,6 +40,7 @@ fn quantize_q8_1(
src: &CudaView<f32>,
dst: &mut CudaSlice<u8>,
elem_count: usize,
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
@ -49,7 +50,7 @@ fn quantize_q8_1(
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
@ -180,7 +181,7 @@ fn mul_mat_vec_via_q8_1(
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
quantize_q8_1(y, &mut y_q8_1, ncols, 1, dev)?;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
@ -216,6 +217,75 @@ fn mul_mat_vec_via_q8_1(
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn mul_mat_via_q8_1(
data: &CudaSlice<u8>,
y: &CudaView<f32>,
dtype: GgmlDType,
x_rows: usize,
x_cols: usize,
y_rows: usize,
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len() / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
}
if y.len() != y_rows * y_cols {
crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
}
if x_cols != y_rows {
crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
}
let k = x_cols;
// Start by quantizing y
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
ceil_div(y_cols, mmq_x) as u32,
1,
),
block_dim: (WARP_SIZE as u32, 4, 1),
shared_mem_bytes: 0,
};
let params = (
/* vx */ data,
/* vy */ &y_q8_1,
/* dst */ &dst,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32,
);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
@ -373,9 +443,30 @@ impl QCudaStorage {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
} else {
let storage = storage.as_cuda_slice::<f32>()?;
let storage = match layout.contiguous_offsets() {
Some((o1, o2)) => storage.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous {
op: "quantized-matmul",
}
.bt())?,
};
mul_mat_via_q8_1(
&self.data,
&storage,
self.dtype,
/* x_rows */ n,
/* x_cols */ k,
/* y_rows */ k,
/* y_cols */ m,
self.device(),
)?
};
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
@ -412,7 +503,7 @@ mod test {
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
@ -453,4 +544,44 @@ mod test {
assert_eq!(vs[0], 5561851.0);
Ok(())
}
#[test]
fn cuda_mm_q8_1() -> Result<()> {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
&xs.data,
&y.slice(..),
/* dtype */ GgmlDType::Q4_0,
/* x_rows */ 4,
/* x_cols */ ncols,
/* y_rows */ ncols,
/* y_cols */ 4,
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
x @ x.t() / 16
tensor([[ 347480.0000, 869720.0000, 1391960.0000, 1914200.0000],
[ 869720.0000, 2440536.0000, 4011352.0000, 5582166.5000],
[ 1391960.0000, 4011352.0000, 6630742.0000, 9250132.0000],
[ 1914200.0000, 5582166.5000, 9250132.0000, 12918099.0000]])
*/
assert_eq!(vs.len(), 16);
assert_eq!(vs[0], 347604.0);
assert_eq!(vs[1], 888153.06);
assert_eq!(vs[4], 869780.7);
assert_eq!(vs[5], 2483145.0);
assert_eq!(vs[11], 9407368.0);
assert_eq!(vs[14], 9470856.0);
assert_eq!(vs[15], 13138824.0);
Ok(())
}
}

View File

@ -71,8 +71,6 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
}
#define CUDA_USE_TENSOR_CORES
#define WARP_SIZE 32
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
@ -103,6 +101,25 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
#define MMQ_Y_Q4_0_PASCAL 64
#define NWARPS_Q4_0_PASCAL 8
#define MMQ_X_Q4_1_RDNA2 64
#define MMQ_Y_Q4_1_RDNA2 128
#define NWARPS_Q4_1_RDNA2 8
#define MMQ_X_Q4_1_RDNA1 64
#define MMQ_Y_Q4_1_RDNA1 64
#define NWARPS_Q4_1_RDNA1 8
#if defined(CUDA_USE_TENSOR_CORES)
#define MMQ_X_Q4_1_AMPERE 4
#define MMQ_Y_Q4_1_AMPERE 32
#define NWARPS_Q4_1_AMPERE 4
#else
#define MMQ_X_Q4_1_AMPERE 64
#define MMQ_Y_Q4_1_AMPERE 128
#define NWARPS_Q4_1_AMPERE 4
#endif
#define MMQ_X_Q4_1_PASCAL 64
#define MMQ_Y_Q4_1_PASCAL 64
#define NWARPS_Q4_1_PASCAL 8
#define MMQ_X_Q5_0_RDNA2 64
#define MMQ_Y_Q5_0_RDNA2 128
#define NWARPS_Q5_0_RDNA2 8
@ -558,6 +575,52 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
GGML_CUDA_ASSUME(i_offset >= 0);
GGML_CUDA_ASSUME(i_offset < nwarps);
GGML_CUDA_ASSUME(k >= 0);
GGML_CUDA_ASSUME(k < WARP_SIZE);
const int kbx = k / QI4_1;
const int kqsx = k % QI4_1;
const block_q4_1 * bx0 = (const block_q4_1 *) vx;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
i = min(i, i_max);
}
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
}
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
if (need_check) {
i = min(i, i_max);
}
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
}
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
(void)x_qh; (void)x_sc;
@ -568,6 +631,16 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(
*x_dm = (half2 *) tile_x_d;
}
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
__shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y];
__shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
*x_ql = tile_x_qs;
*x_dm = tile_x_dm;
}
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;
@ -3493,6 +3566,26 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
int u[2*VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
}
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
(&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
extern "C" __global__ void
mul_mat_q4_0(
@ -3503,10 +3596,24 @@ extern "C" __global__ void
const int nwarps = NWARPS_Q4_0_AMPERE;
mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
load_tiles_q4_0<mmq_y, nwarps, false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
load_tiles_q4_0<mmq_y, nwarps, true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
extern "C" __global__ void
mul_mat_q4_1(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
const int mmq_x = MMQ_X_Q4_1_AMPERE;
const int mmq_y = MMQ_Y_Q4_1_AMPERE;
const int nwarps = NWARPS_Q4_1_AMPERE;
mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
load_tiles_q4_1<mmq_y, nwarps, true>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
extern "C" __global__ void
mul_mat_q5_0(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
@ -3516,7 +3623,7 @@ extern "C" __global__ void
const int nwarps = NWARPS_Q5_0_AMPERE;
mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
load_tiles_q5_0<mmq_y, nwarps, false>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
load_tiles_q5_0<mmq_y, nwarps, true>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
@ -3529,7 +3636,7 @@ mul_mat_q5_1(
const int nwarps = NWARPS_Q5_1_AMPERE;
mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
load_tiles_q5_1<mmq_y, nwarps, false>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
load_tiles_q5_1<mmq_y, nwarps, true>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
@ -3542,7 +3649,7 @@ extern "C" __global__ void
const int nwarps = NWARPS_Q8_0_AMPERE;
mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
load_tiles_q8_0<mmq_y, nwarps, false>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
load_tiles_q8_0<mmq_y, nwarps, true>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
@ -3554,7 +3661,7 @@ mul_mat_q2_K(
const int mmq_y = MMQ_Y_Q2_K_AMPERE;
const int nwarps = NWARPS_Q2_K_AMPERE;
mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
load_tiles_q2_K<mmq_y, nwarps, false>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
load_tiles_q2_K<mmq_y, nwarps, true>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
@ -3566,7 +3673,7 @@ extern "C" __global__ void
const int mmq_y = MMQ_Y_Q3_K_AMPERE;
const int nwarps = NWARPS_Q3_K_AMPERE;
mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
load_tiles_q3_K<mmq_y, nwarps, false>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
load_tiles_q3_K<mmq_y, nwarps, true>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
@ -3578,7 +3685,7 @@ extern "C" __global__ void
const int mmq_y = MMQ_Y_Q4_K_AMPERE;
const int nwarps = NWARPS_Q4_K_AMPERE;
mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
load_tiles_q4_K<mmq_y, nwarps, false>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
load_tiles_q4_K<mmq_y, nwarps, true>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
@ -3590,7 +3697,7 @@ mul_mat_q5_K(
const int mmq_y = MMQ_Y_Q5_K_AMPERE;
const int nwarps = NWARPS_Q5_K_AMPERE;
mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
load_tiles_q5_K<mmq_y, nwarps, false>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
load_tiles_q5_K<mmq_y, nwarps, true>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}
@ -3602,6 +3709,6 @@ extern "C" __global__ void
const int mmq_y = MMQ_Y_Q6_K_AMPERE;
const int nwarps = NWARPS_Q6_K_AMPERE;
mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
load_tiles_q6_K<mmq_y, nwarps, false>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
load_tiles_q6_K<mmq_y, nwarps, true>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
}