Simd128 vec-dot for q4_0. (#974)

* Simd128 vec-dot for q4_0.

* Bugfix.

* Add wasm tests.

* Bugfix for the q40 vecdot.

* More quantization tests.
This commit is contained in:
Laurent Mazare 2023-09-27 14:15:30 +01:00 committed by GitHub
parent e59784e353
commit 667f01c173
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 253 additions and 2 deletions

View File

@ -12,8 +12,9 @@ members = [
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
"candle-wasm-examples/t5",
"candle-wasm-examples/phi",
"candle-wasm-examples/t5",
"candle-wasm-tests",
]
exclude = ["candle-flash-attn", "candle-kernels"]
resolver = "2"

View File

@ -225,6 +225,9 @@ impl GgmlType for BlockQ4_0 {
#[cfg(target_feature = "neon")]
return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
#[cfg(target_feature = "simd128")]
return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {

View File

@ -1,9 +1,59 @@
use super::k_quants::{BlockQ8_0, QK8_0};
use super::k_quants::{BlockQ4_0, BlockQ8_0, QK8_0};
use crate::Result;
use half::f16;
use core::arch::wasm32::*;
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
let x1234 = v128_load(x.qs.as_ptr() as *const v128);
let x12 = v128_and(x1234, u8x16_splat(0x0F));
let x12 = i8x16_sub(x12, i8x16_splat(8));
let x34 = u8x16_shr(x1234, 4);
let x34 = i8x16_sub(x34, i8x16_splat(8));
let x1 = i16x8_extend_low_i8x16(x12);
let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr());
let sum_xy = i32x4_dot_i16x8(x1, y1);
let x2 = i16x8_extend_high_i8x16(x12);
let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2));
let x3 = i16x8_extend_low_i8x16(x34);
let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3));
let x4 = i16x8_extend_high_i8x16(x34);
let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24));
let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4));
let sum_xy = f32x4_convert_i32x4(sum_xy);
// f32x4_relaxed_madd is nightly only.
let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d));
let scaled = f32x4_mul(sum_xy, d);
acc = f32x4_add(acc, scaled)
}
let res = f32x4_extract_lane::<0>(acc)
+ f32x4_extract_lane::<1>(acc)
+ f32x4_extract_lane::<2>(acc)
+ f32x4_extract_lane::<3>(acc);
Ok(res)
}
}
#[inline(always)]
pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;

View File

@ -0,0 +1,15 @@
[package]
name = "candle-wasm-tests"
version.workspace = true
edition.workspace = true
description = "WASM tests for candle"
keywords.workspace = true
categories.workspace = true
[dependencies]
candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
rand = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
[dev-dependencies]
wasm-bindgen-test = "0.3.0"

View File

@ -0,0 +1,12 @@
Run the tests with:
```bash
RUST_LOG=wasm_bindgen_test_runner wasm-pack test --chrome --headless
```
Or:
```bash
wasm-pack test --chrome
```
If you get an "invalid session id" failure in headless mode, check that logs and
it may well be that your ChromeDriver is not at the same version as your
browser.

View File

@ -0,0 +1,14 @@
pub fn add(left: usize, right: usize) -> usize {
left + right
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
}
}

View File

@ -0,0 +1,140 @@
use candle::{
quantized::{self, k_quants, GgmlDType, GgmlType},
test_utils::to_vec2_round,
Device, Result, Tensor,
};
use wasm_bindgen_test::*;
wasm_bindgen_test_configure!(run_in_browser);
#[wasm_bindgen_test]
fn quantized_matmul_neg() -> Result<()> {
let cpu = &Device::Cpu;
let (m, k, n) = (3, 64, 4);
let lhs = (0..(m * k))
.map(|v| v as f32 - (m * k) as f32 / 2.0)
.collect::<Vec<_>>();
let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
let mut dst = vec![42.; 3 * 4];
let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
let rhs = (0..k * n)
.map(|v| v as f32 - (k * n) as f32 / 3.0)
.collect::<Vec<_>>();
let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
assert_eq!(
dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
&[
243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0,
-196472.0, 63012.0, 324585.0, 587902.0
]
);
let mm = tensor_lhs.matmul(&tensor_rhs)?;
assert_eq!(
to_vec2_round(&mm, 0)?,
&[
[244064.0, -20128.0, -284320.0, -548512.0],
[23563.0, 21515.0, 19467.0, 17419.0],
[-196939.0, 63157.0, 323253.0, 583349.0]
]
);
let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
let matmul = quantized::QMatMul::from_qtensor(qtensor);
let res = matmul.forward(&tensor_lhs)?;
assert_eq!(
to_vec2_round(&res, 0)?,
&[
[243524.0, -19596.0, -285051.0, -549815.0],
[23777.0, 21651.0, 19398.0, 18367.0],
[-196472.0, 63012.0, 324585.0, 587902.0]
]
);
Ok(())
}
/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
const GGML_TEST_SIZE: usize = 32 * 128;
(0..GGML_TEST_SIZE)
.map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
.collect()
}
/// Very simple dot product implementation
fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b).map(|(a, b)| a * b).sum()
}
/// Returns the error achieved by the GGML matmul unit test.
fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
let err = match dtype {
GgmlDType::F16 => 0.000010,
GgmlDType::Q2K => 0.004086,
GgmlDType::Q3K => 0.016148,
GgmlDType::Q4K => 0.002425,
GgmlDType::Q5K => 0.000740,
GgmlDType::Q6K => 0.000952,
GgmlDType::Q4_0 => 0.001143,
GgmlDType::Q4_1 => 0.007784,
GgmlDType::Q5_0 => 0.001353,
GgmlDType::Q5_1 => 0.001363,
GgmlDType::Q8_0 => 0.000092,
_ => candle::bail!("No GGML results for quantization type {dtype:?}",),
};
Ok(err)
}
/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;
let a = create_ggml_like_vector(0.0);
let b = create_ggml_like_vector(1.0);
let length = a.len();
let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
T::from_float(&a, &mut a_quant)?;
T::VecDotType::from_float(&b, &mut b_quant)?;
let result = T::vec_dot(length, &a_quant, &b_quant)?;
let reference_result = vec_dot_reference(&a, &b);
let error = (result - reference_result).abs() / length as f32;
let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
if error > GGML_MAX_DOT_PRODUCT_ERROR {
candle::bail!(
"Dot product error {} exceeds max error {}",
error,
GGML_MAX_DOT_PRODUCT_ERROR
);
}
// We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
// => we use a slightly higher error threshold
const ERROR_LENIENCY: f32 = 0.00001;
if error - ERROR_LENIENCY > ggml_error {
candle::bail!(
"Dot product error {} exceeds ggml reference error {}",
error,
ggml_error
);
}
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q40() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ4_0>()?;
Ok(())
}
#[wasm_bindgen_test]
fn quantized_matmul_q80() -> Result<()> {
ggml_matmul_error_test::<candle::quantized::k_quants::BlockQ8_0>()?;
Ok(())
}

View File

@ -0,0 +1,16 @@
{
"moz:firefoxOptions": {
"prefs": {
"media.navigator.streams.fake": true,
"media.navigator.permission.disabled": true
},
"args": []
},
"goog:chromeOptions": {
"args": [
"--use-fake-device-for-media-stream",
"--use-fake-ui-for-media-stream"
]
}
}