diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index ad207461..6e078a6e 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -5,25 +5,32 @@ extern crate accelerate_src; extern crate intel_mkl_src; use anyhow::Result; -use candle_core::{Device, Tensor}; +use candle_core::{Device, Module, Tensor}; + +use candle_core::quantized::{QMatMul, QTensor}; fn main() -> Result<()> { let device = Device::new_cuda(0)?; - let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?; - let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?; - let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?; - println!("{out_t}"); - let in_t = in_t.to_device(&Device::Cpu)?; - let k_t = k_t.to_device(&Device::Cpu)?; - let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?; - let diff = (out_t.to_device(&Device::Cpu)? - out_t2)? - .sqr()? - .sum_all()?; - println!("{diff}"); + let q = Tensor::randn(0f32, 1.0, (72, 32), &device)?; + let q_cpu = q.to_device(&Device::Cpu)?; + let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q4_0)?; + let q = QMatMul::from_qtensor(q)?; + let x = Tensor::randn(0f32, 1.0, (5, 32), &device)?; + let res_q_cuda = q.forward(&x)?; + println!("{res_q_cuda}"); - let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?; - let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?; - let res = t.conv2d(&w, 1, 1, 1, 1)?; - println!("{res:?}"); + let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q4_0)?; + let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?; + let q_cpu = QMatMul::from_qtensor(q_cpu)?; + let x_cpu = x.to_device(&Device::Cpu)?; + let res_q_cpu = q_cpu.forward(&x_cpu)?; + println!("{res_q_cpu}"); + + let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?; + let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))? + .abs()? + .flatten_all()? 
+        .max(0)?;
+    println!("{diff}");
 
     Ok(())
 }
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index c19d7c56..959f0f31 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -827,9 +827,9 @@ impl BackendStorage for MetalStorage {
                     layout.start_offset() * self.dtype.size_in_bytes(),
                 ),
                 &t.buffer,
-                (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
+                (t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
                 &f.buffer,
-                (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
+                (f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
                 &buffer,
             )
             .map_err(MetalError::from)?;
@@ -1264,7 +1264,7 @@ impl BackendStorage for MetalStorage {
            let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
            let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
            let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
-            blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
+            blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
             blit.end_encoding();
         } else {
             let src_shape = src_l.shape();
@@ -1636,7 +1636,7 @@ impl BackendDevice for MetalDevice {
                 min as f32,
                 max as f32,
                 shape.elem_count(),
-                &*self.seed.lock().unwrap(),
+                &self.seed.lock().unwrap(),
                 &buffer,
             )
             .map_err(MetalError::from)?;
@@ -1667,7 +1667,7 @@ impl BackendDevice for MetalDevice {
                 mean as f32,
                 stddev as f32,
                 shape.elem_count(),
-                &*self.seed.lock().unwrap(),
+                &self.seed.lock().unwrap(),
                 &buffer,
             )
             .map_err(MetalError::from)?;
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
new file mode 100644
index 00000000..a2fc6655
--- /dev/null
+++ b/candle-core/src/quantized/cuda.rs
@@ -0,0 +1,321 @@
+use super::{GgmlDType, QStorage};
+use crate::{backend::BackendDevice, cuda_backend::WrapErr};
+use crate::{CudaDevice, CudaStorage, Result};
+
+use cudarc::driver::{CudaSlice, DeviceSlice};
+
+pub struct QCudaStorage {
+    // The quantized blocks are stored as raw bytes on the device.
+    data: CudaSlice<u8>,
+    dtype: GgmlDType,
+    device: CudaDevice,
+}
+
+pub const WARP_SIZE: usize = 32;
+pub const MMQ_X_Q4_0_AMPERE: usize = 4;
+pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
+pub const NWARPS_Q4_0_AMPERE: usize = 4;
+pub const GGML_CUDA_MMV_X: usize = 32;
+pub const GGML_CUDA_MMV_Y: usize = 1;
+
+fn dequantize(
+    data: &CudaSlice<u8>,
+    dtype: GgmlDType,
+    elem_count: usize,
+    dev: &CudaDevice,
+) -> Result<CudaStorage> {
+    use cudarc::driver::LaunchAsync;
+
+    let (kernel_name, is_k) = match dtype {
+        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false),
+        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false),
+        GgmlDType::Q5_0 => ("dequantize_block_q5_0", false),
+        GgmlDType::Q5_1 => ("dequantize_block_q5_1", false),
+        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false),
+        GgmlDType::Q2K => ("dequantize_block_q2_K", true),
+        GgmlDType::Q3K => ("dequantize_block_q3_K", true),
+        GgmlDType::Q4K => ("dequantize_block_q4_K", true),
+        GgmlDType::Q5K => ("dequantize_block_q5_K", true),
+        GgmlDType::Q6K => ("dequantize_block_q6_K", true),
+        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
+    };
+    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
+    // Each CUDA block dequantizes 256 output values.
+    let nb = (elem_count + 255) / 256;
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (nb as u32, 1, 1),
+        block_dim: (32, 1, 1),
+        shared_mem_bytes: 0,
+    };
+
+    if is_k {
+        let params = (data, &dst);
+        unsafe { func.launch(cfg, params) }.w()?;
+    } else {
+        let nb32 = elem_count / 32;
+        let params = (data, &dst, nb32 as i32);
+        unsafe { func.launch(cfg, params) }.w()?;
+    }
+    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
+fn dequantize_mul_mat_vec(
+    data: &CudaSlice<u8>,
+    y: &cudarc::driver::CudaView<f32>,
+    dtype: GgmlDType,
+    ncols: usize,
+    nrows: usize,
+    dev: &CudaDevice,
+) -> Result<CudaStorage> {
+    use cudarc::driver::LaunchAsync;
+
+    let kernel_name = match dtype {
+        GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
+        GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
+        GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
+        GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
+        GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
+        GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
+        GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
+        GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
+        GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
+        GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
+        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
+    };
+    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = dev.alloc_zeros::<f32>(nrows).w()?;
+    let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (block_num_y as u32, 1, 1),
+        block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
+        shared_mem_bytes: 0,
+    };
+
+    let params = (data, y, &dst, ncols as i32, nrows as i32);
+    unsafe { func.launch(cfg, params) }.w()?;
+    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
+impl QCudaStorage {
+    pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
+        let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
+        let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
+        Ok(QCudaStorage {
+            data,
+            device: device.clone(),
+            dtype,
+        })
+    }
+
+    pub fn dtype(&self) -> GgmlDType {
+        self.dtype
+    }
+
+    pub fn device(&self) -> &CudaDevice {
+        &self.device
+    }
+
+    pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
+        let fast_kernel = matches!(
+            self.dtype,
+            GgmlDType::Q4_0
+                | GgmlDType::Q4_1
+                | GgmlDType::Q5_0
+                | GgmlDType::Q5_1
+                | GgmlDType::Q8_0
+                | GgmlDType::Q2K
+                | GgmlDType::Q3K
+                | GgmlDType::Q4K
+                | GgmlDType::Q5K
+                | GgmlDType::Q6K
+        );
+        if fast_kernel {
+            return dequantize(&self.data, self.dtype, elem_count, self.device());
+        }
+        // Run the dequantization on cpu.
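+        // Fallback for dtypes without a dedicated CUDA kernel above (F32, F16,
+        // Q8_1, Q8K): copy the raw quantized blocks back to the host, dequantize
+        // with the CPU k-quants implementation, then upload the f32 result. This
+        // is correct but pays for two device/host transfers.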
+        use crate::quantized::k_quants::GgmlType;
+
+        let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
+        let mut out = vec![0.0; elem_count];
+        let block_len = elem_count / self.dtype.block_size();
+        match self.dtype {
+            GgmlDType::F32 => {
+                let slice =
+                    unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
+                out.copy_from_slice(slice)
+            }
+            GgmlDType::F16 => {
+                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
+                half::f16::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4_0 => {
+                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4_1 => {
+                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5_0 => {
+                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5_1 => {
+                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8_0 => {
+                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8_1 => {
+                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q2K => {
+                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q3K => {
+                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q4K => {
+                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q5K => {
+                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q6K => {
+                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
+            }
+            GgmlDType::Q8K => {
+                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
+                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
+            }
+        }
+
+        self.device
+            .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
+    }
+
+    pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
+        // Run the quantization on cpu.
+        let src = match &src.slice {
+            crate::cuda_backend::CudaStorageSlice::F32(data) => {
+                self.device.dtoh_sync_copy(data).w()?
+            }
+            _ => crate::bail!("only f32 can be quantized"),
+        };
+        let src_len = src.len();
+        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
+        let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
+        qcpu_storage.quantize(&src)?;
+        let data = qcpu_storage.data()?;
+        let data = self.device.htod_sync_copy(data.as_ref()).w()?;
+        self.data = data;
+        Ok(())
+    }
+
+    pub fn storage_size_in_bytes(&self) -> usize {
+        self.data.len()
+    }
+
+    pub fn fwd(
+        &self,
+        self_shape: &crate::Shape,
+        storage: &CudaStorage,
+        layout: &crate::Layout,
+    ) -> Result<(CudaStorage, crate::Shape)> {
+        // A single-row rhs goes through the fused dequantize-matmul-vec kernel,
+        // everything else dequantizes the weights and runs a regular matmul.
+        let dmmv = matches!(layout.shape().dims(), [1, 1, _] | [1, _]);
+        if dmmv {
+            self.dequantize_matmul_vec(self_shape, storage, layout)
+        } else {
+            self.dequantize_matmul(self_shape, storage, layout)
+        }
+    }
+}
+
+impl QCudaStorage {
+    fn dequantize_matmul_vec(
+        &self,
+        self_shape: &crate::Shape,
+        rhs: &CudaStorage,
+        rhs_l: &crate::Layout,
+    ) -> Result<(CudaStorage, crate::Shape)> {
+        let (nrows, ncols) = self_shape.dims2()?;
+        let rhs = rhs.as_cuda_slice::<f32>()?;
+        let rhs = match rhs_l.contiguous_offsets() {
+            Some((o1, o2)) => rhs.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
+        };
+        let (with_batch, k) = match rhs_l.shape().dims() {
+            [1, 1, k] => (true, k),
+            [1, k] => (false, k),
+            _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
+        };
+        if ncols != *k {
+            crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
+        }
+
+        let out =
+            dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
+        let out_shape = if with_batch {
+            vec![1, 1, nrows]
+        } else {
+            vec![1, nrows]
+        };
+        Ok((out, out_shape.into()))
+    }
+
+    fn dequantize_matmul(
+        &self,
+        self_shape: &crate::Shape,
+        storage: &CudaStorage,
+        layout: &crate::Layout,
+    ) -> Result<(CudaStorage, crate::Shape)> {
+        use crate::backend::BackendStorage;
+        let (n, k) = self_shape.dims2()?;
+        let (b, m, k2) = match layout.shape().dims() {
+            &[b, m, k2] => (b, m, k2),
+            &[m, k2] => (1, m, k2),
+            s => crate::bail!("unexpected shape for input {s:?}"),
+        };
+        if k2 != k {
+            crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
+        }
+
+        let data_f32 = self.dequantize(n * k)?;
+        // Express the transposed weight matrix through strides rather than copying.
+        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
+        let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
+        let mut out_shape = layout.shape().dims().to_vec();
+        out_shape.pop();
+        out_shape.push(n);
+        Ok((out, out_shape.into()))
+    }
+}
+
+fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
+    let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
+    slice.to_vec()
+}
+
+pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
+    device: &CudaDevice,
+    data: &[T],
+) -> Result<QStorage> {
+    let data = unsafe {
+        std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
+    };
+    let data = device.htod_sync_copy(data).w()?;
+    Ok(QStorage::Cuda(QCudaStorage {
+        data,
+        device: device.clone(),
+        dtype: T::DTYPE,
+    }))
+}
diff --git a/candle-core/src/quantized/dummy_cuda.rs b/candle-core/src/quantized/dummy_cuda.rs
new file mode 100644
index 00000000..598c5cd1
--- /dev/null
+++ b/candle-core/src/quantized/dummy_cuda.rs
@@ -0,0 +1,50 @@
+#![allow(unused)]
+use super::GgmlDType;
+use crate::{CudaDevice, CudaStorage, Error, Result};
+
+pub struct QCudaStorage {
+    dtype: GgmlDType,
+    device: CudaDevice,
+}
+
+impl QCudaStorage {
+    pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
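+        // Stub: this module only exists so that builds without the `cuda` feature
+        // still type-check; every entry point reports the missing support at runtime.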
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    pub fn dtype(&self) -> GgmlDType {
+        self.dtype
+    }
+
+    pub fn device(&self) -> &CudaDevice {
+        &self.device
+    }
+
+    pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
+    pub fn storage_size_in_bytes(&self) -> usize {
+        0
+    }
+
+    pub fn fwd(
+        &self,
+        _self_shape: &crate::Shape,
+        _storage: &CudaStorage,
+        _layout: &crate::Layout,
+    ) -> Result<(CudaStorage, crate::Shape)> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+}
+
+pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
+    _device: &CudaDevice,
+    _data: &[T],
+) -> Result<super::QStorage> {
+    Err(Error::NotCompiledWithCudaSupport)
+}
diff --git a/candle-core/src/quantized/dummy_metal.rs b/candle-core/src/quantized/dummy_metal.rs
index 96f91c50..520d0ed4 100644
--- a/candle-core/src/quantized/dummy_metal.rs
+++ b/candle-core/src/quantized/dummy_metal.rs
@@ -41,3 +41,10 @@ impl QMetalStorage {
         Err(Error::NotCompiledWithMetalSupport)
     }
 }
+
+pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
+    _device: &MetalDevice,
+    _data: &[T],
+) -> Result<super::QStorage> {
+    Err(Error::NotCompiledWithMetalSupport)
+}
diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs
index e6f5791c..99200bbd 100644
--- a/candle-core/src/quantized/ggml_file.rs
+++ b/candle-core/src/quantized/ggml_file.rs
@@ -1,7 +1,5 @@
 //! Support for the GGML file format.
 
-#[cfg(feature = "metal")]
-use super::metal::load_quantized_metal;
 use super::{k_quants, GgmlDType, QStorage};
 use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt};
@@ -130,13 +128,8 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
     let data: QStorage = match device {
         Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
-        #[cfg(feature = "metal")]
-        Device::Metal(metal) => load_quantized_metal(metal, data)?,
-        #[cfg(not(feature = "metal"))]
-        Device::Metal(_metal) => {
-            crate::bail!("Metal backend requires `metal` feature")
-        }
-        device => unimplemented!("Implement quantized tensor for device {device:?}"),
+        Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
+        Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
     };
     super::QTensor::new(data, dims)
 }
diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index 5cdfe6ab..af1cf369 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -34,6 +34,8 @@ impl QMetalStorage {
     }
 
     pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
+        use crate::quantized::k_quants::GgmlType;
+
         let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("to_cpu");
@@ -43,81 +45,62 @@ impl QMetalStorage {
         blit.end_encoding();
         self.device.wait_until_completed()?;
         let mut out = vec![0.0; elem_count];
+        let block_len = elem_count / self.dtype.block_size();
         match self.dtype {
             GgmlDType::F32 => {
-                let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<f32> = read_to_vec(&buffer, block_len);
                 f32::to_float(&vec, &mut out)?;
             }
             GgmlDType::F16 => {
-                let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
                 half::f16::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q4_0 => {
-                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q4_1 => {
-                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q5_0 => {
-                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q5_1 => {
-                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q8_0 => {
-                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q8_1 => {
-                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q2K => {
-                let vec: Vec<crate::quantized::BlockQ2K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q3K => {
-                let vec: Vec<crate::quantized::BlockQ3K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q4K => {
-                let vec: Vec<crate::quantized::BlockQ4K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q5K => {
-                let vec: Vec<crate::quantized::BlockQ5K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q6K => {
-                let vec: Vec<crate::quantized::BlockQ6K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q8K => {
-                let vec: Vec<crate::quantized::BlockQ8K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
+                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
                 crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
             }
         }
@@ -192,7 +175,7 @@ impl QMetalStorage {
     }
 }
 
-pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
+pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
     device: &MetalDevice,
     data: &[T],
 ) -> Result<QStorage> {
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index d14b2dc2..f7abcd93 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -4,6 +4,7 @@ use std::borrow::Cow;
 
 #[cfg(target_feature = "avx")]
 pub mod avx;
+mod dummy_cuda;
 mod dummy_metal;
 pub mod ggml_file;
 pub mod gguf_file;
@@ -14,6 +15,13 @@ pub mod metal;
 mod metal {
     pub use super::dummy_metal::*;
 }
+#[cfg(feature = "cuda")]
+pub mod cuda;
+#[cfg(not(feature = "cuda"))]
+mod cuda {
+    pub use super::dummy_cuda::*;
+}
+
 #[cfg(target_feature = "neon")]
 pub mod neon;
 #[cfg(target_feature = "simd128")]
 pub mod simd128;
@@ -39,8 +47,9 @@ impl Device {
             let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
             Ok(QStorage::Metal(storage))
         }
-        Device::Cuda(_cuda) => {
-            crate::bail!("Cuda ggml quantization not supported");
+        Device::Cuda(cuda) => {
+            let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
+            Ok(QStorage::Cuda(storage))
         }
     }
 }
@@ -49,6 +58,7 @@ impl Device {
 pub enum QStorage {
     Cpu(Box<dyn QuantizedType>),
     Metal(metal::QMetalStorage),
+    Cuda(cuda::QCudaStorage),
 }
 
 impl QStorage {
@@ -56,6 +66,7 @@ impl QStorage {
         match self {
             QStorage::Cpu(storage) => storage.block_size(),
             QStorage::Metal(storage) => storage.dtype().block_size(),
+            QStorage::Cuda(storage) => storage.dtype().block_size(),
         }
     }
 
@@ -63,6 +74,7 @@ impl QStorage {
         match self {
             QStorage::Cpu(storage) => storage.dtype(),
             QStorage::Metal(storage) => storage.dtype(),
+            QStorage::Cuda(storage) => storage.dtype(),
         }
     }
 
@@ -70,6 +82,7 @@ impl QStorage {
         match self {
             QStorage::Cpu(_storage) => Device::Cpu,
             QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
+            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
         }
     }
 
@@ -77,6 +90,7 @@ impl QStorage {
         match self {
             QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
             QStorage::Metal(storage) => storage.storage_size_in_bytes(),
+            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
         }
     }
 
@@ -86,6 +100,7 @@ impl QStorage {
                 storage.from_float(src.as_slice::<f32>()?)?;
             }
             (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
+            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
             _ => crate::bail!("Invalid dequantize storage locations do not match"),
         }
         Ok(())
@@ -95,6 +110,7 @@ impl QStorage {
         match self {
             QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
             QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
+            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
         }
     }
 
@@ -106,7 +122,7 @@ impl QStorage {
                 let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
                 Ok(Cow::from(data))
             }
-            QStorage::Metal(_storage) => {
+            QStorage::Metal(_) | QStorage::Cuda(_) => {
                 crate::bail!("not implemented");
             }
         }
@@ -424,7 +440,7 @@ impl crate::CustomOp1 for QTensor {
         #[allow(clippy::infallible_destructuring_match)]
         let self_storage = match &self.storage {
             QStorage::Cpu(storage) => storage,
-            QStorage::Metal(_) => crate::bail!("Invalid storage"),
+            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
         };
         let slice = storage.as_slice::<f32>()?;
         let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
@@ -444,6 +460,18 @@ impl crate::CustomOp1 for QTensor {
         };
         self_storage.fwd(&self.shape, storage, layout)
     }
+
+    fn cuda_fwd(
+        &self,
+        storage: &crate::CudaStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::CudaStorage, Shape)> {
+        let self_storage = match &self.storage {
+            QStorage::Cuda(cuda) => cuda,
+            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
+        };
+        self_storage.fwd(&self.shape, storage, layout)
+    }
 }
 
 impl crate::Module for QMatMul {
diff --git a/candle-examples/examples/custom-ops/cuda_kernels.rs b/candle-examples/examples/custom-ops/cuda_kernels.rs
index e69de29b..c00b601b 100644
--- a/candle-examples/examples/custom-ops/cuda_kernels.rs
+++ b/candle-examples/examples/custom-ops/cuda_kernels.rs
@@ -0,0 +1 @@
+pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs
index 478debd3..dc1195cb 100644
--- a/candle-kernels/src/lib.rs
+++ b/candle-kernels/src/lib.rs
@@ -4,6 +4,7 @@
 pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
 pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
 pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
 pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
+pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
 pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
 pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
 pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
diff --git a/candle-kernels/src/quantized.cu b/candle-kernels/src/quantized.cu
new file mode 100644
index 00000000..bf81487a
--- /dev/null
+++ b/candle-kernels/src/quantized.cu
@@ -0,0 +1,1536 @@
+// Kernels adapted from llama.cpp ggml-cuda.cu
+// https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda.cu
+#include "cuda_fp16.h"
+#include "cuda_bf16.h"
+#include <cstdint>
+
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
+#define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
+
+#undef GGML_CUDA_F16
+#define GGML_CUDA_DMMV_X 32
+#define CUDA_QUANTIZE_BLOCK_SIZE 256
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+#define K_QUANTS_PER_ITERATION 2
+
+typedef uint16_t ggml_fp16_t;
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+
+static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+
+    int x32 = 0;
+    x32 |= x16[0] <<  0;
+    x32 |= x16[1] << 16;
+
+    return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+
+    int x32 = 0;
+    x32 |= x16[0] <<  0;
+    x32 |= x16[1] << 16;
+
+    return x32;
+}
+
+static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+}
+
+
+#define CUDA_USE_TENSOR_CORES
+
+#define WARP_SIZE 32
+#define CUDART_HMAX 11070 // CUDA 11.7, min. ver.
for which __hmax and __hmax2 are known to work (may be higher than needed) + +#define CC_PASCAL 600 +#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define CC_VOLTA 700 +#define CC_OFFSET_AMD 1000000 +#define CC_RDNA1 (CC_OFFSET_AMD + 1010) +#define CC_RDNA2 (CC_OFFSET_AMD + 1030) +#define CC_RDNA3 (CC_OFFSET_AMD + 1100) + +#define MMQ_X_Q4_0_RDNA2 64 +#define MMQ_Y_Q4_0_RDNA2 128 +#define NWARPS_Q4_0_RDNA2 8 +#define MMQ_X_Q4_0_RDNA1 64 +#define MMQ_Y_Q4_0_RDNA1 64 +#define NWARPS_Q4_0_RDNA1 8 +#if defined(CUDA_USE_TENSOR_CORES) +#define MMQ_X_Q4_0_AMPERE 4 +#define MMQ_Y_Q4_0_AMPERE 32 +#define NWARPS_Q4_0_AMPERE 4 +#else +#define MMQ_X_Q4_0_AMPERE 64 +#define MMQ_Y_Q4_0_AMPERE 128 +#define NWARPS_Q4_0_AMPERE 4 +#endif +#define MMQ_X_Q4_0_PASCAL 64 +#define MMQ_Y_Q4_0_PASCAL 64 +#define NWARPS_Q4_0_PASCAL 8 + +// QK = number of values after dequantization +// QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization + +#define QK4_0 32 +#define QR4_0 2 +#define QI4_0 (QK4_0 / (4 * QR4_0)) +typedef struct { + half d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +#define QR4_1 2 +#define QI4_1 (QK4_1 / (4 * QR4_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +#define QR5_0 2 +#define QI5_0 (QK5_0 / (4 * QR5_0)) +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +#define QR5_1 2 +#define QI5_1 (QK5_1 / (4 * QR5_1)) +typedef struct { + half2 dm; // dm.x = delta, dm.y = min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +#define QR8_0 1 +#define QI8_0 (QK8_0 / (4 * QR8_0)) +typedef struct { + half d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +#define QR8_1 1 +#define QI8_1 (QK8_1 / (4 * QR8_1)) +typedef struct { + half2 ds; // ds.x = delta, ds.y = sum + int8_t qs[QK8_0]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); + +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); +typedef void (*load_tiles_cuda_t)( + const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row); +typedef float (*vec_dot_q_mul_mat_cuda_t)( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, 
    const int & i, const int & j, const int & k);
+
+#define QR2_K 4
+#define QI2_K (QK_K / (4*QR2_K))
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    half2 dm;                // super-block scale for quantized scales/mins
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
+
+#define QR3_K 4
+#define QI3_K (QK_K / (4*QR3_K))
+typedef struct {
+    uint8_t hmask[QK_K/8];     // quants - high bit
+    uint8_t qs[QK_K/4];        // quants - low 2 bits
+#ifdef GGML_QKK_64
+    uint8_t scales[2];         // scales, quantized with 8 bits
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;                    // super-block scale
+} block_q3_K;
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
+
+#define QR4_K 2
+#define QI4_K (QK_K / (4*QR4_K))
+#ifdef GGML_QKK_64
+typedef struct {
+    half dm[2];              // super-block scales/mins
+    uint8_t scales[2];       // 4-bit block scales/mins
+    uint8_t qs[QK_K/2];      // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
+typedef struct {
+    half2 dm;                  // super-block scale for quantized scales/mins
+    uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
+
+#define QR5_K 2
+#define QI5_K (QK_K / (4*QR5_K))
+#ifdef GGML_QKK_64
+typedef struct {
+    half d;                  // super-block scale
+    int8_t scales[QK_K/16];  // block scales
+    uint8_t qh[QK_K/8];      // quants, high bit
+    uint8_t qs[QK_K/2];      // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+    half2 dm;                     // super-block scale for quantized scales/mins
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];           // quants, high bit
+    uint8_t qs[QK_K/2];           // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
+
+#define QR6_K 2
+#define QI6_K (QK_K / (4*QR6_K))
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t scales[QK_K/16];  // scales
+    half d;                  // delta
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
+
+
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
+    const int * v, const int * u, const float & d4, const half2 & ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+        // SIMD dot product of quantized values
+        sumi = __dp4a(vi0, u[2*i+0], sumi);
+        sumi = __dp4a(vi1, u[2*i+1], sumi);
+    }
+
+    const float2 ds8f = __half22float2(ds8);
+
+    // second part effectively subtracts 8 from each quant value
+    const float res = d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
+    return res;
+}
+
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
+    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
+    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;
+
+    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
+    const float * x_dmf = (const float *) x_dm;
+
+    int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
+        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
+    }
+
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
+         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
+}
+
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
+          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
+static __device__ __forceinline__ void mul_mat_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+
+    const block_q_t  * x = (const block_q_t  *) vx;
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    const int blocks_per_row_x = ncols_x / qk;
+    const int blocks_per_col_y = nrows_y / QK8_1;
+    const int blocks_per_warp = WARP_SIZE / qi;
+
+    const int & ncols_dst = ncols_y;
+
+    const int row_dst_0 = blockIdx.x*mmq_y;
+    const int & row_x_0 = row_dst_0;
+
+    const int col_dst_0 = blockIdx.y*mmq_x;
+    const int & col_y_0 = col_dst_0;
+
+    int   * tile_x_ql = nullptr;
+    half2 * tile_x_dm = nullptr;
+    int   * tile_x_qh = nullptr;
+    int   * tile_x_sc = nullptr;
+
+    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
+
+    __shared__ int   tile_y_qs[mmq_x * WARP_SIZE];
+    __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
+
+    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
+
+    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
+
+        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
+                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
+
+#pragma unroll
+        for (int ir = 0; ir < qr; ++ir) {
+            const int kqs = ir*WARP_SIZE + threadIdx.x;
+            const int kbxd = kqs / QI8_1;
+
+#pragma unroll
+            for (int i = 0; i < mmq_x; i += nwarps) {
+                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
+
+                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
+
+                const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
+                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
+            }
+
+#pragma unroll
+            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
+                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
+                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
+
+                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
+                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
+                if (need_sum) {
+                    *dsi_dst = *dsi_src;
+                } else {
+                    float * dfi_dst = (float *) dsi_dst;
+                    *dfi_dst = __low2float(*dsi_src);
+                }
+            }
+
+            __syncthreads();
+
+// #pragma unroll // unrolling this loop causes too much register pressure
+            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
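+                // Hot loop: each thread consumes vdr packed ints of its tile slice per
+                // step, accumulating vec_dot results into per-thread partial sums that
+                // are only written to dst once the whole row of blocks is processed.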
+#pragma unroll
+                for (int j = 0; j < mmq_x; j += nwarps) {
+#pragma unroll
+                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+                        sum[i/WARP_SIZE][j/nwarps] += vec_dot(
+                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
+                            threadIdx.x + i, threadIdx.y + j, k);
+                    }
+                }
+            }
+
+            __syncthreads();
+        }
+    }
+
+#pragma unroll
+    for (int j = 0; j < mmq_x; j += nwarps) {
+        const int col_dst = col_dst_0 + j + threadIdx.y;
+
+        if (col_dst >= ncols_dst) {
+            return;
+        }
+
+#pragma unroll
+        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+            const int row_dst = row_dst_0 + threadIdx.x + i;
+
+            if (row_dst >= nrows_dst) {
+                continue;
+            }
+
+            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
+    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
+
+    const int kbx  = k / QI4_0;
+    const int kqsx = k % QI4_0;
+
+    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
+
+    float * x_dmf = (float *) x_dm;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + i_offset;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
+
+        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
+        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    const int kbxd = k % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+
+        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+    }
+}
+
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;
+
+    __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
+    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
+
+    *x_ql = tile_x_qs;
+    *x_dm = (half2 *) tile_x_d;
+}
+
+extern "C" __global__ void mul_mat_q4_0_check(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
+    const int nwarps = NWARPS_Q4_0_AMPERE;
+
+    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+        load_tiles_q4_0<mmq_y, nwarps, true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+extern "C" __global__ void mul_mat_q4_0_no_check(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
+    const int mmq_x  =  MMQ_X_Q4_0_AMPERE;
+    const int mmq_y  =  MMQ_Y_Q4_0_AMPERE;
+    const int nwarps = NWARPS_Q4_0_AMPERE;
+
+    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+        load_tiles_q4_0<mmq_y, nwarps, false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+}
+
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
+
+#ifdef GGML_CUDA_F16
+    v = __hsub2(v, {8.0f, 8.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 8.0f) * d;
+    v.y = (v.y - 8.0f) * d;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
+
+#ifdef GGML_CUDA_F16
+    v = __hsub2(v, {16.0f, 16.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 16.0f) * d;
+    v.y = (v.y - 16.0f) * d;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    v.x = x[ib].qs[iqs + 0];
+    v.y = x[ib].qs[iqs + 1];
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+#else
+    v.x *= d;
+    v.y *= d;
+#endif // GGML_CUDA_F16
+}
+
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
+    const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+
+    if (i >= k) {
+        return;
+    }
+
+    const int ib = i/qk; // block index
+    const int iqs = (i%qk)/qr; // quant index
+    const int iybs = i - i%qk; // y block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(vx, ib, iqs, v);
+
+    y[iybs + iqs + 0]        = v.x;
+    y[iybs + iqs + y_offset] = v.y;
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
+
+    const int i = blockIdx.x;
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int ib = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    float * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
+    const float d = __half2float(x->d);
+    const float dm = -8*d;
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l+ 0] = d * (q[l] & 0xF) + dm;
+        y[l+16] = d * (q[l] >>  4) + dm;
+    }
+}
+
+extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
+
+    const int i = blockIdx.x;
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int il  = tid/8;
+    const int ir  = tid%8;
+    const int ib = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    float * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
+    const float2 d = __half22float2(x->dm);
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
+        y[l+16] = d.x * (q[l] >>  4) + d.y;
+    }
+}
+
+//================================== k-quants
+
+extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
+
+    const int i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
+    const int tid = threadIdx.x;
+#if QK_K == 256
+    const int n   = tid/32;
+    const int l   = tid - 32*n;
+    const int is  = 8*n + l/16;
+
+    const uint8_t q = x[i].qs[32*n + l];
+    float * y = yy + i*QK_K + 128*n;
+
+    float dall = __low2half(x[i].dm);
+    float dmin = __high2half(x[i].dm);
+    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+    const int is = tid/16;  // 0 or 1
+    const int il = tid%16;  // 0...15
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    float * y = yy + i*QK_K + 16*is + il;
+    float dall = __low2half(x[i].dm);
+    float dmin = __high2half(x[i].dm);
+    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
+
+}
+
+extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
+
+    const int i = blockIdx.x;
+    const block_q3_K * x = (const block_q3_K *) vx;
+
+#if QK_K == 256
+    const int r = threadIdx.x/4;
+    const int tid = r/2;
+    const int is0 = r%2;
+    const int l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int n = tid / 4;
+    const int j = tid - 4*n;
+
+    uint8_t m = 1 << (4*n + j);
+    int is = 8*n + 2*j + is0;
+    int shift = 2*j;
+
+    int8_t us = is < 4 ?
(x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + float d_all = x[i].d; + float dl = d_all * (us - 32); + + float * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +#else + const int tid = threadIdx.x; + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const int im = il/8; // 0...1 + const int in = il%8; // 0...7 + + float * y = yy + i*QK_K + 16*is + il; + + const uint8_t q = x[i].qs[il] >> (2*is); + const uint8_t h = x[i].hmask[in] >> (2*is + im); + const float d = (float)x[i].d; + + if (is == 0) { + y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } else { + y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } +#endif + +} + +#if QK_K == 256 +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} +#endif + +extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) { + const block_q4_K * x = (const block_q4_K *) vx; + + const int i = blockIdx.x; + +#if QK_K == 256 + // assume 32 threads + const int tid = threadIdx.x; + const int il = tid/8; + const int ir = tid%8; + const int is = 2*il; + const int n = 4; + + float * y = yy + i*QK_K + 64*il + n*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * q = x[i].qs + 32*il + n*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q[l] & 0xF) - m1; + y[l +32] = d2 * (q[l] >> 4) - m2; + } +#else + const int tid = threadIdx.x; + const uint8_t * q = x[i].qs; + float * y = yy + i*QK_K; + const float d = (float)x[i].dm[0]; + const float m = (float)x[i].dm[1]; + y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); + y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); +#endif +} + +extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) { + const block_q5_K * x = (const block_q5_K *) vx; + + const int i = blockIdx.x; + +#if QK_K == 256 + // assume 64 threads - this is very slightly better than the one below + const int tid = threadIdx.x; + const int il = tid/16; // il is in 0...3 + const int ir = tid%16; // ir is in 0...15 + const int is = 2*il; // is is in 0...6 + + float * y = yy + i*QK_K + 64*il + 2*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is 
+ 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+
+    uint8_t   hm  = 1 << (2*il);
+    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+    hm <<= 1;
+    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+#else
+    const int tid = threadIdx.x;
+    const uint8_t q = x[i].qs[tid];
+    const int im = tid/8;  // 0...3
+    const int in = tid%8;  // 0...7
+    const int is = tid/16; // 0 or 1
+    const uint8_t h = x[i].qh[in] >> im;
+    const float d = x[i].d;
+    float * y = yy + i*QK_K + tid;
+    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
+    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
+#endif
+}
+
+extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
+    const block_q6_K * x = (const block_q6_K *) vx;
+
+    const int i = blockIdx.x;
+#if QK_K == 256
+
+    // assume 64 threads - this is very slightly better than the one below
+    const int tid = threadIdx.x;
+    const int ip  = tid/32;      // ip is 0 or 1
+    const int il  = tid - 32*ip; // 0...32
+    const int is  = 8*ip + il/16;
+
+    float * y = yy + i*QK_K + 128*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t * ql = x[i].ql + 64*ip + il;
+    const uint8_t qh = x[i].qh[32*ip + il];
+    const int8_t * sc = x[i].scales + is;
+
+    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+#else
+
+    // assume 32 threads
+    const int tid = threadIdx.x;
+    const int ip  = tid/16;         // 0 or 1
+    const int il  = tid - 16*ip;    // 0...15
+
+    float * y = yy + i*QK_K + 16*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t ql = x[i].ql[16*ip + il];
+    const uint8_t qh = x[i].qh[il] >> (2*ip);
+    const int8_t * sc = x[i].scales;
+
+    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[ip+2] * ((int8_t)((ql  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+#endif
+}
+
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    const int tid = threadIdx.x;
+
+    const int iter_stride = 2*GGML_CUDA_DMMV_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+// partial sum for each thread
+#ifdef GGML_CUDA_F16
+    half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+    float tmp = 0.0f;
+#endif // GGML_CUDA_F16
+
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
+        const int iybs = col - col%qk; // y block start index
+
+// processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+            dfloat2 v;
+            dequantize_kernel(vx, ib, iqs + j/qr, v);
+
+            // matrix multiplication
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_CUDA_F16
+            tmp += __hmul2(v, {
+                y[iybs + iqs + j/qr + 0],
+                y[iybs + iqs + j/qr + y_offset]
+            });
+#else
+            tmp += v.x * y[iybs + iqs + j/qr + 0];
+            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+#endif // GGML_CUDA_F16
+        }
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+#ifdef GGML_CUDA_F16
+        dst[row] = tmp.x + tmp.y;
+#else
+        dst[row] = tmp;
+#endif // GGML_CUDA_F16
+    }
+}
+
+extern "C" __global__ void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows);
+}
+
+extern "C" __global__ void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows);
+}
+
+extern "C" __global__ void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows);
+}
+
+extern "C" __global__ void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows);
+}
+extern "C" __global__ void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows);
+}
+
+extern "C" __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row > nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q2_K * x = (const block_q2_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
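+    // Thread layout: each warp reduces one output row; `im`/`in` pick the
+    // 128-value half and the 16-value slice of the super-block that this thread
+    // scans, K_QUANTS_PER_ITERATION quants at a time, while `ix` staggers
+    // threads across consecutive super-blocks.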
+    const int in = tid - step*im; // 0...15 or 0...7
+
+    const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
+
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float * y = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
+
+        const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  + y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
+
+        }
+        tmp += dall * sum1 - dmin * sum2;
+
+    }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0...1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;
+
+    uint32_t uaux[2];
+    const uint8_t * d = (const uint8_t *)uaux;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint32_t * s = (const uint32_t *)x[i].scales;
+
+        uaux[0] = s[0] & 0x0f0f0f0f;
+        uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
+
+        const float2 dall = __half22float2(x[i].dm);
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t ql = q[l];
+            sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+                  + y[l+16] * d[1] * ((ql >> 2) & 3)
+                  + y[l+32] * d[2] * ((ql >> 4) & 3)
+                  + y[l+48] * d[3] * ((ql >> 6) & 3);
+            sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
+        }
+        tmp += dall.x * sum1 - dall.y * sum2;
+    }
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = tmp;
+    }
+}
+
+extern "C" __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q3_K * x = (const block_q3_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+    const int ix = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
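+    // q3_K packs each weight as 2 low bits in qs plus 1 high bit in hmask; a
+    // cleared high bit subtracts 4 (the `? 0 : 4` terms below). The eight
+    // 6-bit scales used by this half, biased by 32, are reassembled into
+    // utmp[]/s[] from the packed x[i].scales array inside the main loop.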
+    const int in = tid - step*im; // 0...15 or 0...7
+
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in; // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float * y = yy + i * QK_K + y_offset;
+        const uint8_t * q = x[i].qs + q_offset;
+        const uint8_t * h = x[i].hmask + l0;
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                 + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp += d * sum;
+
+    }
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0...1 or 0...3
+    const int offset = tid * K_QUANTS_PER_ITERATION;        // 0...15 or 0...14
+    const int in = offset/8; // 0 or 1
+    const int im = offset%8; // 0...7
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float * y = yy + i * QK_K + offset;
+        const uint8_t * q = x[i].qs + offset;
+        const uint8_t * s = x[i].scales;
+
+        const float dall = (float)x[i].d;
+
+        float sum = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            const uint8_t hl = x[i].hmask[im+l] >> in;
+            const uint8_t ql = q[l];
+            sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+                 + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+                 + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+                 + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
+        }
+        tmp += sum;
+    }
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = tmp;
+    }
+}
+
+extern "C" __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q4_K * x = (const block_q4_K *)vx + ib0;
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+    const int ix = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
+
+    const int il = tid/step;      // 0...3
+    const int ir = tid - step*il; // 0...7 or 0...3
+    const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
+
+    const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+#if K_QUANTS_PER_ITERATION == 2
+    uint32_t q32[4];
+    const uint8_t * q4 = (const uint8_t *)q32;
+#else
+    uint16_t q16[4];
+    const uint8_t * q4 = (const uint8_t *)q16;
+#endif
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float * y1 = yy + i*QK_K + y_offset;
+        const float * y2 = y1 + 128;
+
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+#if K_QUANTS_PER_ITERATION == 2
+        const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
+        const uint32_t * q2 = q1 + 16;
+
+        q32[0] = q1[0] & 0x0f0f0f0f;
+        q32[1] = q1[0] & 0xf0f0f0f0;
+        q32[2] = q2[0] & 0x0f0f0f0f;
+        q32[3] = q2[0] & 0xf0f0f0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 4; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
+            s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#else
+        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
+        const uint16_t * q2 = q1 + 32;
+
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[0] & 0xf0f0;
+        q16[2] = q2[0] & 0x0f0f;
+        q16[3] = q2[0] & 0xf0f0;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        for (int l = 0; l < 2; ++l) {
+            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
+            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
+#endif
+
+    }
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    float tmp = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const float * y = yy + i*QK_K + step;
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux16[0] = a[0] & 0x0f0f;
+        aux16[1] = (a[0] >> 4) & 0x0f0f;
+        const float d = (float)x[i].dm[0];
+        const float m = (float)x[i].dm[1];
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+                 + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+                 + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+                 + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
+        }
+        tmp += sum;
+    }
+
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
+
+extern "C" __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
+
+    const int row = blockIdx.x;
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q5_K * x = (const block_q5_K *)vx + ib0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+#if QK_K == 256
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = threadIdx.x/2; // 0...15
+    const int ix = threadIdx.x%2;
+
+    const int il = tid/4;      // 0...3
+    const int ir = tid - 4*il; // 0...3
+    const int n = 2;
+
+    const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1 = 1 << (2*im);
+    const uint8_t hm2 = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    uint16_t q16[8];
+    const uint8_t * q4 = (const uint8_t *)q16;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        const uint8_t * ql1 = x[i].qs + q_offset;
+        const uint8_t * qh = x[i].qh + l0;
+        const float * y1 = yy + i*QK_K + y_offset;
+        const float * y2 = y1 + 128;
+
+        const float dall = __low2half(x[i].dm);
+        const float dmin = __high2half(x[i].dm);
+
+        const uint16_t * a = (const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = {0.f, 0.f, 0.f, 0.f};
+        float smin = 0;
+        const uint16_t * q1 = (const uint16_t *)ql1;
+        const uint16_t * q2 = q1 + 32;
+        q16[0] = q1[0] & 0x0f0f;
+        q16[1] = q1[8] & 0x0f0f;
+        q16[2] = (q1[0] >> 4) & 0x0f0f;
+        q16[3] = (q1[8] >> 4) & 0x0f0f;
+        q16[4] = q2[0] & 0x0f0f;
+        q16[5] = q2[8] & 0x0f0f;
+        q16[6] = (q2[0] >> 4) & 0x0f0f;
+        q16[7] = (q2[8] >> 4) & 0x0f0f;
+        for (int l = 0; l < n; ++l) {
+            sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+        }
+        tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+    }
+
+#else
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
+    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int im = step/8;
+    const int in = step%8;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+        const uint8_t * q = x[i].qs + step;
+        const int8_t * s = x[i].scales;
+        const float * y = yy + i*QK_K + step;
+        const float d = x[i].d;
+        float sum = 0.f;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            const uint8_t h = x[i].qh[in+j] >> im;
+            sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+                 + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+                 + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
+                 + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
+        }
+        tmp += sum;
+    }
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (threadIdx.x == 0) {
+        dst[row] = tmp;
+    }
+}
+
+extern "C" __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
+
+    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
+
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    if (row >= nrows) return;
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const block_q6_K * x = (const block_q6_K *)vx + ib0;
+
+#if QK_K == 256
+
+    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+    const int ix = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+
+    const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
+
+    const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
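+    // q6_K stores each weight as 4 low bits in ql plus 2 high bits in qh;
+    // they are recombined below and biased by -32 to give a signed value in
+    // [-32, 31], scaled by the per-16-value int8 scale s[] and the
+    // super-block scale d.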
+    const int in = tid - step*im; // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in; // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset = 8*im + is;
+    const int y_offset = 128*im + l0;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        const float * y = yy + i * QK_K + y_offset;
+        const uint8_t * ql = x[i].ql + ql_offset;
+        const uint8_t * qh = x[i].qh + qh_offset;
+        const int8_t * s = x[i].scales + s_offset;
+
+        const float d = x[i].d;
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  + y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp += sum;
+#endif
+
+    }
+
+#else
+
+    const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
+    const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0...3
+
+    const int step = tid * K_QUANTS_PER_ITERATION;
+
+    float tmp = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+
+        const float * y = yy + i * QK_K + step;
+        const uint8_t * ql = x[i].ql + step;
+        const uint8_t * qh = x[i].qh + step;
+        const int8_t * s = x[i].scales;
+
+        const float d = x[i].d;
+
+        float sum = 0;
+        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+            sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+                 + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+                 + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+                 + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
+        }
+        tmp += sum;
+
+    }
+
+#endif
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    }
+
+    if (tid == 0) {
+        dst[row] = tmp;
+    }
+}
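+
+// All of the dequantize_mul_mat_vec_* kernels above follow the same pattern:
+// each thread accumulates a partial dot product for one matrix row, the
+// __shfl_xor_sync butterfly reduction combines the 32 lanes of a warp into a
+// full row sum, and a single lane writes dst[row]. The host-side launch is
+// not part of this hunk; a sketch following the launch configuration of the
+// upstream llama.cpp kernels these are ported from (one warp per row, `ny`
+// rows per block) would be:
+//
+//     const int ny = 2; // rows per block
+//     dim3 block_dims(32, ny, 1);
+//     dim3 grid_dims((nrows + ny - 1) / ny, 1, 1);
+//     dequantize_mul_mat_vec_q2_k<<<grid_dims, block_dims>>>(vx, y, dst, ncols, nrows);
+//
+// which is why the kernels that take nrows guard with `row >= nrows` when
+// nrows is not a multiple of ny.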