Simd128 vec-dot for q4_0. (#974)

* Simd128 vec-dot for q4_0.
* Bugfix.
* Add wasm tests.
* Bugfix for the q40 vecdot.
* More quantization tests.
parent e59784e353, commit 667f01c173
@@ -12,8 +12,9 @@ members = [
     "candle-wasm-examples/whisper",
     "candle-wasm-examples/yolo",
     "candle-wasm-examples/bert",
-    "candle-wasm-examples/t5",
     "candle-wasm-examples/phi",
+    "candle-wasm-examples/t5",
+    "candle-wasm-tests",
 ]
 exclude = ["candle-flash-attn", "candle-kernels"]
 resolver = "2"
@@ -225,6 +225,9 @@ impl GgmlType for BlockQ4_0 {
         #[cfg(target_feature = "neon")]
         return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);

+        #[cfg(target_feature = "simd128")]
+        return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
+
         let qk = QK8_0;
         let nb = n / qk;
         if n % QK8_0 != 0 {
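The backend here is selected at compile time, not at runtime: when the wasm `simd128` target feature is enabled (for example via `RUSTFLAGS="-C target-feature=+simd128"` on `wasm32-unknown-unknown`), the cfg-gated `return` short-circuits past the scalar loop, and when it is disabled the whole cfg block compiles away. A minimal sketch of that pattern, with placeholder bodies that are not candle code:

```rust
// Sketch: cfg-gated early return with a scalar fallback.
#[allow(unreachable_code)]
fn dot_dispatch(n: usize) -> f32 {
    // Only compiled in when built with the `simd128` target feature;
    // the early return then makes the code below unreachable.
    #[cfg(target_feature = "simd128")]
    return n as f32 * 2.0; // stand-in for the SIMD kernel

    // Scalar fallback, used when no SIMD backend is compiled in.
    n as f32 + 1.0
}
```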
@@ -1,9 +1,59 @@
-use super::k_quants::{BlockQ8_0, QK8_0};
+use super::k_quants::{BlockQ4_0, BlockQ8_0, QK8_0};
 use crate::Result;
 use half::f16;

 use core::arch::wasm32::*;

+#[inline(always)]
+pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
+    let qk = QK8_0;
+    if n % QK8_0 != 0 {
+        crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
+    }
+    let nb = n / QK8_0;
+    if nb % 2 != 0 {
+        crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
+    }
+    unsafe {
+        let mut acc = f32x4_splat(0.0f32);
+        for (x, y) in xs.iter().zip(ys.iter()) {
+            let x1234 = v128_load(x.qs.as_ptr() as *const v128);
+            let x12 = v128_and(x1234, u8x16_splat(0x0F));
+            let x12 = i8x16_sub(x12, i8x16_splat(8));
+            let x34 = u8x16_shr(x1234, 4);
+            let x34 = i8x16_sub(x34, i8x16_splat(8));
+
+            let x1 = i16x8_extend_low_i8x16(x12);
+            let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
+            let sum_xy = i32x4_dot_i16x8(x1, y1);
+
+            let x2 = i16x8_extend_high_i8x16(x12);
+            let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
+
+            let x3 = i16x8_extend_low_i8x16(x34);
+            let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
+
+            let x4 = i16x8_extend_high_i8x16(x34);
+            let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
+            let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
+
+            let sum_xy = f32x4_convert_i32x4(sum_xy);
+
+            // f32x4_relaxed_madd is nightly only.
+            let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
+            let scaled = f32x4_mul(sum_xy, d);
+            acc = f32x4_add(acc, scaled)
+        }
+        let res = f32x4_extract_lane::<0>(acc)
+            + f32x4_extract_lane::<1>(acc)
+            + f32x4_extract_lane::<2>(acc)
+            + f32x4_extract_lane::<3>(acc);
+        Ok(res)
+    }
+}
+
 #[inline(always)]
 pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
     let qk = QK8_0;
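For reference, here is a scalar sketch of the per-block arithmetic the kernel above vectorizes. It assumes the standard q4_0 layout (32 weights per block packed into 16 bytes, low nibbles holding elements 0..15 and high nibbles elements 16..31, each stored with a +8 bias); the function name and the local block structs are hypothetical stand-ins mirroring candle's `k_quants` types, not candle code:

```rust
use half::f16;

// Hypothetical stand-ins mirroring the k_quants block layouts.
pub struct BlockQ4_0 {
    d: f16,       // per-block scale
    qs: [u8; 16], // 32 4-bit weights, two per byte
}
pub struct BlockQ8_0 {
    d: f16,       // per-block scale
    qs: [i8; 32], // 32 8-bit weights
}

fn vec_dot_q4_0_q8_0_scalar(xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in xs.iter().zip(ys.iter()) {
        let mut sum = 0i32;
        for i in 0..16 {
            // Low nibble is element i, high nibble is element i + 16;
            // subtracting 8 removes the storage bias, as the SIMD code does
            // with i8x16_sub(.., i8x16_splat(8)).
            let lo = (x.qs[i] & 0x0F) as i32 - 8;
            let hi = (x.qs[i] >> 4) as i32 - 8;
            sum += lo * y.qs[i] as i32 + hi * y.qs[i + 16] as i32;
        }
        // One multiply by both f16 block scales, matching the f32x4_mul above.
        acc += sum as f32 * f16::to_f32(x.d) * f16::to_f32(y.d);
    }
    acc
}
```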
@@ -0,0 +1,15 @@
[package]
name = "candle-wasm-tests"
version.workspace = true
edition.workspace = true
description = "WASM tests for candle"
keywords.workspace = true
categories.workspace = true

[dependencies]
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }

[dev-dependencies]
wasm-bindgen-test = "0.3.0"
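The `js` feature on `getrandom` is what lets `rand` obtain entropy in a browser: on `wasm32-unknown-unknown` the crate has no default entropy backend and fails to build without it.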
@@ -0,0 +1,12 @@
Run the tests with:
```bash
RUST_LOG=wasm_bindgen_test_runner wasm-pack test --chrome --headless
```
Or:
```bash
wasm-pack test --chrome
```

If you get an "invalid session id" failure in headless mode, check the logs;
it may well be that your ChromeDriver is not at the same version as your
browser.
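Note that these invocations exercise the simd128 kernels only if the build enables the corresponding target feature; otherwise the dispatch in `k_quants.rs` falls through to the scalar implementation. Assuming your wasm-pack version honors `RUSTFLAGS`, something like `RUSTFLAGS='-C target-feature=+simd128' wasm-pack test --chrome --headless` should cover the SIMD path.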
@@ -0,0 +1,14 @@
pub fn add(left: usize, right: usize) -> usize {
    left + right
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let result = add(2, 2);
        assert_eq!(result, 4);
    }
}
@@ -0,0 +1,140 @@
use candle::{
    quantized::{self, k_quants, GgmlDType, GgmlType},
    test_utils::to_vec2_round,
    Device, Result, Tensor,
};

use wasm_bindgen_test::*;
wasm_bindgen_test_configure!(run_in_browser);

#[wasm_bindgen_test]
fn quantized_matmul_neg() -> Result<()> {
    let cpu = &Device::Cpu;
    let (m, k, n) = (3, 64, 4);
    let lhs = (0..(m * k))
        .map(|v| v as f32 - (m * k) as f32 / 2.0)
        .collect::<Vec<_>>();
    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
    let mut dst = vec![42.; 3 * 4];
    let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
    let rhs = (0..k * n)
        .map(|v| v as f32 - (k * n) as f32 / 3.0)
        .collect::<Vec<_>>();
    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
    k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
    assert_eq!(
        dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
        &[
            243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0,
            -196472.0, 63012.0, 324585.0, 587902.0
        ]
    );
    let mm = tensor_lhs.matmul(&tensor_rhs)?;
    assert_eq!(
        to_vec2_round(&mm, 0)?,
        &[
            [244064.0, -20128.0, -284320.0, -548512.0],
            [23563.0, 21515.0, 19467.0, 17419.0],
            [-196939.0, 63157.0, 323253.0, 583349.0]
        ]
    );

    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor);
    let res = matmul.forward(&tensor_lhs)?;
    assert_eq!(
        to_vec2_round(&res, 0)?,
        &[
            [243524.0, -19596.0, -285051.0, -549815.0],
            [23777.0, 21651.0, 19398.0, 18367.0],
            [-196472.0, 63012.0, 324585.0, 587902.0]
        ]
    );

    Ok(())
}

/// Creates a vector similarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
    const GGML_TEST_SIZE: usize = 32 * 128;
    (0..GGML_TEST_SIZE)
        .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
        .collect()
}

/// Very simple dot product implementation
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(a, b)| a * b).sum()
}

/// Returns the error achieved by the GGML matmul unit test.
fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
    let err = match dtype {
        GgmlDType::F16 => 0.000010,
        GgmlDType::Q2K => 0.004086,
        GgmlDType::Q3K => 0.016148,
        GgmlDType::Q4K => 0.002425,
        GgmlDType::Q5K => 0.000740,
        GgmlDType::Q6K => 0.000952,
        GgmlDType::Q4_0 => 0.001143,
        GgmlDType::Q4_1 => 0.007784,
        GgmlDType::Q5_0 => 0.001353,
        GgmlDType::Q5_1 => 0.001363,
        GgmlDType::Q8_0 => 0.000092,
        _ => candle::bail!("No GGML results for quantization type {dtype:?}",),
    };
    Ok(err)
}

/// Mirrors the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
    const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
    let a = create_ggml_like_vector(0.0);
    let b = create_ggml_like_vector(1.0);
    let length = a.len();

    let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
    let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
    T::from_float(&a, &mut a_quant)?;
    T::VecDotType::from_float(&b, &mut b_quant)?;

    let result = T::vec_dot(length, &a_quant, &b_quant)?;
    let reference_result = vec_dot_reference(&a, &b);

    let error = (result - reference_result).abs() / length as f32;

    let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;

    if error > GGML_MAX_DOT_PRODUCT_ERROR {
        candle::bail!(
            "Dot product error {} exceeds max error {}",
            error,
            GGML_MAX_DOT_PRODUCT_ERROR
        );
    }

    // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
    // => we use a slightly higher error threshold
    const ERROR_LENIENCY: f32 = 0.00001;
    if error - ERROR_LENIENCY > ggml_error {
        candle::bail!(
            "Dot product error {} exceeds ggml reference error {}",
            error,
            ggml_error
        );
    }
    Ok(())
}

#[wasm_bindgen_test]
fn quantized_matmul_q40() -> Result<()> {
    ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ4_0>()?;
    Ok(())
}

#[wasm_bindgen_test]
fn quantized_matmul_q80() -> Result<()> {
    ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ8_0>()?;
    Ok(())
}
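Since `ggml_matmul_error_test` is generic over `GgmlType`, covering another quantization format is one wrapper per type. A hypothetical q4_1 wrapper (not part of this commit, but `BlockQ4_1` and its reference error are already handled above) would look like:

```rust
// Hypothetical extension inside the same test file.
#[wasm_bindgen_test]
fn quantized_matmul_q41() -> Result<()> {
    ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ4_1>()?;
    Ok(())
}
```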
@@ -0,0 +1,16 @@
{
    "moz:firefoxOptions": {
        "prefs": {
            "media.navigator.streams.fake": true,
            "media.navigator.permission.disabled": true
        },
        "args": []
    },
    "goog:chromeOptions": {
        "args": [
            "--use-fake-device-for-media-stream",
            "--use-fake-ui-for-media-stream"
        ]
    }
}
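This `webdriver.json` is picked up by the wasm-bindgen test runner to set browser capabilities; the fake media-stream prefs and flags (a common template for wasm tests) let headless Chrome and Firefox start without prompting for device permissions.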