From 667f01c17323a5c28a9ae12d9f4512c36cc411b9 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 27 Sep 2023 14:15:30 +0100 Subject: [PATCH] Simd128 vec-dot for q4_0. (#974) * Simd128 vec-dot for q4_0. * Bugfix. * Add wasm tests. * Bugfix for the q40 vecdot. * More quantization tests. --- Cargo.toml | 3 +- candle-core/src/quantized/k_quants.rs | 3 + candle-core/src/quantized/simd128.rs | 52 +++++++- candle-wasm-tests/Cargo.toml | 15 +++ candle-wasm-tests/README.md | 12 ++ candle-wasm-tests/src/lib.rs | 14 +++ candle-wasm-tests/tests/quantized_tests.rs | 140 +++++++++++++++++++++ candle-wasm-tests/webdriver.json | 16 +++ 8 files changed, 253 insertions(+), 2 deletions(-) create mode 100644 candle-wasm-tests/Cargo.toml create mode 100644 candle-wasm-tests/README.md create mode 100644 candle-wasm-tests/src/lib.rs create mode 100644 candle-wasm-tests/tests/quantized_tests.rs create mode 100644 candle-wasm-tests/webdriver.json diff --git a/Cargo.toml b/Cargo.toml index bcb90217..c502fcc5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,8 +12,9 @@ members = [ "candle-wasm-examples/whisper", "candle-wasm-examples/yolo", "candle-wasm-examples/bert", - "candle-wasm-examples/t5", "candle-wasm-examples/phi", + "candle-wasm-examples/t5", + "candle-wasm-tests", ] exclude = ["candle-flash-attn", "candle-kernels"] resolver = "2" diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index e0218c34..064692b7 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -225,6 +225,9 @@ impl GgmlType for BlockQ4_0 { #[cfg(target_feature = "neon")] return super::neon::vec_dot_q4_0_q8_0(n, xs, ys); + #[cfg(target_feature = "simd128")] + return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys); + let qk = QK8_0; let nb = n / qk; if n % QK8_0 != 0 { diff --git a/candle-core/src/quantized/simd128.rs b/candle-core/src/quantized/simd128.rs index 9cb7119f..c093f189 100644 --- a/candle-core/src/quantized/simd128.rs +++ 
b/candle-core/src/quantized/simd128.rs @@ -1,9 +1,59 @@ -use super::k_quants::{BlockQ8_0, QK8_0}; +use super::k_quants::{BlockQ4_0, BlockQ8_0, QK8_0}; use crate::Result; use half::f16; use core::arch::wasm32::*; +#[inline(always)] +pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result { + let qk = QK8_0; + if n % QK8_0 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}") + } + let nb = n / QK8_0; + if nb % 2 != 0 { + crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even") + } + unsafe { + let mut acc = f32x4_splat(0.0f32); + for (x, y) in xs.iter().zip(ys.iter()) { + let x1234 = v128_load(x.qs.as_ptr() as *const v128); + let x12 = v128_and(x1234, u8x16_splat(0x0F)); + let x12 = i8x16_sub(x12, i8x16_splat(8)); + let x34 = u8x16_shr(x1234, 4); + let x34 = i8x16_sub(x34, i8x16_splat(8)); + + let x1 = i16x8_extend_low_i8x16(x12); + let y1 = i16x8_load_extend_i8x8(y.qs.as_ptr()); + let sum_xy = i32x4_dot_i16x8(x1, y1); + + let x2 = i16x8_extend_high_i8x16(x12); + let y2 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(8)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x2, y2)); + + let x3 = i16x8_extend_low_i8x16(x34); + let y3 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(16)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x3, y3)); + + let x4 = i16x8_extend_high_i8x16(x34); + let y4 = i16x8_load_extend_i8x8(y.qs.as_ptr().add(24)); + let sum_xy = i32x4_add(sum_xy, i32x4_dot_i16x8(x4, y4)); + + let sum_xy = f32x4_convert_i32x4(sum_xy); + + // f32x4_relaxed_madd is nightly only. 
+ let d = f32x4_splat(f16::to_f32(x.d) * f16::to_f32(y.d)); + let scaled = f32x4_mul(sum_xy, d); + acc = f32x4_add(acc, scaled) + } + let res = f32x4_extract_lane::<0>(acc) + + f32x4_extract_lane::<1>(acc) + + f32x4_extract_lane::<2>(acc) + + f32x4_extract_lane::<3>(acc); + Ok(res) + } +} + #[inline(always)] pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) -> Result { let qk = QK8_0; diff --git a/candle-wasm-tests/Cargo.toml b/candle-wasm-tests/Cargo.toml new file mode 100644 index 00000000..5d8553d7 --- /dev/null +++ b/candle-wasm-tests/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "candle-wasm-tests" +version.workspace = true +edition.workspace = true +description = "WASM tests for candle" +keywords.workspace = true +categories.workspace = true + +[dependencies] +candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" } +rand = { workspace = true } +getrandom = { version = "0.2", features = ["js"] } + +[dev-dependencies] +wasm-bindgen-test = "0.3.0" diff --git a/candle-wasm-tests/README.md b/candle-wasm-tests/README.md new file mode 100644 index 00000000..3b4d5c4e --- /dev/null +++ b/candle-wasm-tests/README.md @@ -0,0 +1,12 @@ +Run the tests with: +```bash +RUST_LOG=wasm_bindgen_test_runner wasm-pack test --chrome --headless +``` +Or: +```bash +wasm-pack test --chrome +``` + +If you get an "invalid session id" failure in headless mode, check the logs and +it may well be that your ChromeDriver is not at the same version as your +browser. 
diff --git a/candle-wasm-tests/src/lib.rs b/candle-wasm-tests/src/lib.rs new file mode 100644 index 00000000..7d12d9af --- /dev/null +++ b/candle-wasm-tests/src/lib.rs @@ -0,0 +1,14 @@ +pub fn add(left: usize, right: usize) -> usize { + left + right +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn it_works() { + let result = add(2, 2); + assert_eq!(result, 4); + } +} diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs new file mode 100644 index 00000000..0594a4fa --- /dev/null +++ b/candle-wasm-tests/tests/quantized_tests.rs @@ -0,0 +1,140 @@ +use candle::{ + quantized::{self, k_quants, GgmlDType, GgmlType}, + test_utils::to_vec2_round, + Device, Result, Tensor, +}; + +use wasm_bindgen_test::*; +wasm_bindgen_test_configure!(run_in_browser); + +#[wasm_bindgen_test] +fn quantized_matmul_neg() -> Result<()> { + let cpu = &Device::Cpu; + let (m, k, n) = (3, 64, 4); + let lhs = (0..(m * k)) + .map(|v| v as f32 - (m * k) as f32 / 2.0) + .collect::>(); + let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?; + let mut dst = vec![42.; 3 * 4]; + let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8]; + let rhs = (0..k * n) + .map(|v| v as f32 - (k * n) as f32 / 3.0) + .collect::>(); + let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?; + k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?; + k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?; + assert_eq!( + dst.iter().map(|x| x.round()).collect::>(), + &[ + 243524.0, -19596.0, -285051.0, -549815.0, 23777.0, 21651.0, 19398.0, 18367.0, + -196472.0, 63012.0, 324585.0, 587902.0 + ] + ); + let mm = tensor_lhs.matmul(&tensor_rhs)?; + assert_eq!( + to_vec2_round(&mm, 0)?, + &[ + [244064.0, -20128.0, -284320.0, -548512.0], + [23563.0, 21515.0, 19467.0, 17419.0], + [-196939.0, 63157.0, 323253.0, 583349.0] + ] + ); + + let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; + let matmul = quantized::QMatMul::from_qtensor(qtensor); + let res = 
matmul.forward(&tensor_lhs)?; + assert_eq!( + to_vec2_round(&res, 0)?, + &[ + [243524.0, -19596.0, -285051.0, -549815.0], + [23777.0, 21651.0, 19398.0, 18367.0], + [-196472.0, 63012.0, 324585.0, 587902.0] + ] + ); + + Ok(()) +} + +/// Creates a vector similarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30 +fn create_ggml_like_vector(offset: f32) -> Vec { + const GGML_TEST_SIZE: usize = 32 * 128; + (0..GGML_TEST_SIZE) + .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos()) + .collect() +} + +/// Very simple dot product implementation +fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(a, b)| a * b).sum() +} + +/// Returns the error achieved by the GGML matmul unit test. +fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result { + let err = match dtype { + GgmlDType::F16 => 0.000010, + GgmlDType::Q2K => 0.004086, + GgmlDType::Q3K => 0.016148, + GgmlDType::Q4K => 0.002425, + GgmlDType::Q5K => 0.000740, + GgmlDType::Q6K => 0.000952, + GgmlDType::Q4_0 => 0.001143, + GgmlDType::Q4_1 => 0.007784, + GgmlDType::Q5_0 => 0.001353, + GgmlDType::Q5_1 => 0.001363, + GgmlDType::Q8_0 => 0.000092, + _ => candle::bail!("No GGML results for quantization type {dtype:?}",), + }; + Ok(err) +} + +/// Mirrors the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91 +fn ggml_matmul_error_test() -> Result<()> { + const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02; + let a = create_ggml_like_vector(0.0); + let b = create_ggml_like_vector(1.0); + let length = a.len(); + + let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE]; + let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE]; + T::from_float(&a, &mut a_quant)?; + T::VecDotType::from_float(&b, &mut b_quant)?; + + let result = T::vec_dot(length, &a_quant, &b_quant)?; + let reference_result = vec_dot_reference(&a, &b); + + let error = (result - 
reference_result).abs() / length as f32; + + let ggml_error = ggml_reference_matmul_error(T::DTYPE)?; + + if error > GGML_MAX_DOT_PRODUCT_ERROR { + candle::bail!( + "Dot product error {} exceeds max error {}", + error, + GGML_MAX_DOT_PRODUCT_ERROR + ); + } + + // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML + // => we use a slightly higher error threshold + const ERROR_LENIENCY: f32 = 0.00001; + if error - ERROR_LENIENCY > ggml_error { + candle::bail!( + "Dot product error {} exceeds ggml reference error {}", + error, + ggml_error + ); + } + Ok(()) +} + +#[wasm_bindgen_test] +fn quantized_matmul_q40() -> Result<()> { + ggml_matmul_error_test::()?; + Ok(()) +} + +#[wasm_bindgen_test] +fn quantized_matmul_q80() -> Result<()> { + ggml_matmul_error_test::()?; + Ok(()) +} diff --git a/candle-wasm-tests/webdriver.json b/candle-wasm-tests/webdriver.json new file mode 100644 index 00000000..7e28821b --- /dev/null +++ b/candle-wasm-tests/webdriver.json @@ -0,0 +1,16 @@ +{ + "moz:firefoxOptions": { + "prefs": { + "media.navigator.streams.fake": true, + "media.navigator.permission.disabled": true + }, + "args": [] + }, + "goog:chromeOptions": { + "args": [ + "--use-fake-device-for-media-stream", + "--use-fake-ui-for-media-stream" + ] + } +} +