From aa79e36a8de31e93579b02c2103ec0e2f93f4ff3 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 17 Sep 2024 10:08:20 -0400 Subject: [PATCH] Add more quantization support for burn-jit (#2275) * Add cubecl quantization kernels and QTensorOps for burn-jit * Fix typo * Fix output vec factor * Fix output dtype size_of * Remove unused code in dequantize test * Fix dequantize vectorization * Handle tensors when number of elems is not a multiple of 4 * Support quantize for tensors with less than 4 elems (no vectorization) * Fix equal 0 test * Add quantize/dequantize tests * Add q_to_device * Refactor kernels for latest cubecl * intermediate i32 cast * Fix size_of output type * Use strict=false to ignore floating point precision issues with qparams equality * Only check that lhs & rhs strategies match (but not strict on qparams values) * Use assert_approx_eq on dequant values * Reduce precision for flaky test * Remove todo comment * Add comment for cast to unsigned * More comment --------- Co-authored-by: louisfd --- crates/burn-jit/src/backend.rs | 3 +- crates/burn-jit/src/kernel/mod.rs | 2 + .../src/kernel/quantization/dequantize.rs | 211 +++++++++++++++++ .../burn-jit/src/kernel/quantization/mod.rs | 5 + .../src/kernel/quantization/quantize.rs | 219 ++++++++++++++++++ crates/burn-jit/src/ops/qtensor.rs | 130 +++++++++-- crates/burn-jit/src/tensor/qtensor.rs | 91 +++++++- crates/burn-jit/src/tests/mod.rs | 11 + crates/burn-jit/src/tests/quantization.rs | 82 +++++++ crates/burn-tensor/src/tensor/data.rs | 27 ++- .../src/tests/quantization/ops/aggregation.rs | 2 +- .../src/tests/quantization/ops/quantize.rs | 11 +- 12 files changed, 744 insertions(+), 50 deletions(-) create mode 100644 crates/burn-jit/src/kernel/quantization/dequantize.rs create mode 100644 crates/burn-jit/src/kernel/quantization/mod.rs create mode 100644 crates/burn-jit/src/kernel/quantization/quantize.rs create mode 100644 crates/burn-jit/src/tests/quantization.rs diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index d3945d3c5..720349fed 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -34,7 +34,8 @@ where type FloatTensorPrimitive = JitTensor; type IntTensorPrimitive = JitTensor; type BoolTensorPrimitive = JitTensor; - type QuantizedTensorPrimitive = QJitTensor; + type QuantizedTensorPrimitive = + QJitTensor; fn name() -> String { format!("jit<{}>", R::name()) diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index cb3bbb8a1..724a94e2d 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -25,6 +25,8 @@ pub mod matmul; pub mod pool; /// Pseudo-random number generator kernels pub mod prng; +/// Quantization operations +pub mod quantization; /// Reduction algorithms pub mod reduce; diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs new file mode 100644 index 000000000..5480486f1 --- /dev/null +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -0,0 +1,211 @@ +use crate::tensor::{JitTensor, QJitTensor}; +use crate::FloatElement; +use crate::{IntElement, JitElement, JitRuntime}; +use burn_tensor::quantization::{QuantizationScheme, QuantizationType}; +use cubecl::calculate_cube_count_elemwise; +use cubecl::prelude::*; + +#[cube] +pub(crate) fn dequantize_affine_int8(value: i32, scale: F, offset: i32) -> F { + // x = scale * (x_q - offset) + scale * (F::cast_from(value) - F::cast_from(offset)) +} + +#[cube] +pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 { + // Extract 8-bit segment + let value = (value >> offset) & 0xFF; + // Check if the value is negative by inspecting the MSB and subtract 256 if it is + // Subtract 0 or 256 to circumvent unsupported conditional assignment (let x = if {} else {};) + let sub = i32::cast_from(value & 0x80 != 0) * 256; + i32::cast_from(value) - sub +} + +#[cube(launch_unchecked)] +pub(crate) fn dequantize_per_tensor_affine_int8_kernel( + input: &Tensor, + scale: &Tensor, + offset: &Tensor, + output: &mut Tensor, + #[comptime] vectorized: bool, +) { + if ABSOLUTE_POS >= output.len() { + return; + } + + let scale = scale[0]; + let offset = offset[0]; + + let num_packed = 4; + let value = input[ABSOLUTE_POS]; + let output_pos = ABSOLUTE_POS * num_packed; + + if vectorized { + let vectorization_factor = vectorization_of(input); + #[unroll] + for i in 0..vectorization_factor { + // Extract each 8-bit segment + let v1 = extract_i8(value[i], 24); + let v2 = extract_i8(value[i], 16); + let v3 = extract_i8(value[i], 8); + let v4 = extract_i8(value[i], 0); + + output[output_pos * vectorization_factor + i * num_packed] = + dequantize_affine_int8::(v1, scale, offset); + output[output_pos * vectorization_factor + i * num_packed + 1] = + dequantize_affine_int8::(v2, scale, offset); + output[output_pos * vectorization_factor + i * num_packed + 2] = + dequantize_affine_int8::(v3, scale, offset); + output[output_pos * vectorization_factor + i * num_packed + 3] = + dequantize_affine_int8::(v4, scale, offset); + } + } else { + // Extract each 8-bit segment + let v1 = extract_i8(value, 24); + let v2 = extract_i8(value, 16); + let v3 = extract_i8(value, 8); + let v4 = extract_i8(value, 0); + + output[output_pos] = dequantize_affine_int8::(v1, scale, offset); + output[output_pos + 1] = dequantize_affine_int8::(v2, scale, offset); + output[output_pos + 2] = dequantize_affine_int8::(v3, scale, offset); + output[output_pos + 3] = dequantize_affine_int8::(v4, scale, offset); + } +} + +#[cube] +pub(crate) fn dequantize_symmetric_int8(value: i32, scale: F) -> F { + // x = scale * x_q + scale * F::cast_from(value) +} + +// Would have wrapped symmetric with the same affine kernel but cube doesn't support Option for offset. +#[cube(launch_unchecked)] +pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( + input: &Tensor, + scale: &Tensor, + output: &mut Tensor, + #[comptime] vectorized: bool, +) { + if ABSOLUTE_POS >= output.len() { + return; + } + + let scale = scale[0]; + + let num_packed = 4; + let value = input[ABSOLUTE_POS]; + let output_pos = ABSOLUTE_POS * num_packed; + + if vectorized { + let vectorization_factor = vectorization_of(input); + #[unroll] + for i in 0..vectorization_factor { + for j in 0..num_packed { + let output_idx = output_pos * vectorization_factor + i * num_packed + j; + if output_idx >= output.len() { + return; // value not quantized (padding) + } + // Extract each 8-bit segment + let v = extract_i8(value[i], (3 - j) * 8); + output[output_idx] = dequantize_symmetric_int8::(v, scale); + } + } + } else { + // Extract each 8-bit segment + for j in 0..num_packed { + let output_idx = output_pos + j; + if output_idx >= output.len() { + return; // value not quantized (padding) + } + // Extract each 8-bit segment + let v = extract_i8(value, (3 - j) * 8); + output[output_pos + j] = dequantize_symmetric_int8::(v, scale); + } + } +} + +pub(crate) fn dequantize_per_tensor( + tensor: JitTensor, + scale: JitTensor, + offset: Option>, +) -> JitTensor +where + R: JitRuntime, + F: JitElement, + I: IntElement, +{ + // The actual number of elements is 1/4 (four int8 values packed in a single u32) + // so we choose a vectorization factor to match a valid input binding size. + let num_out_elems = tensor.shape.num_elements(); + let num_elems = usize::div_ceil(num_out_elems, 4); + let vectorization_factor = [4u8, 2, 1] + .iter() + .filter_map(|&v| { + if num_elems >= v as usize { + Some(v) + } else { + None + } + }) + .next() + .unwrap(); + let cube_dim = CubeDim::default(); + let cube_count = + calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); + + let shape_output = tensor.shape.clone(); + let client = tensor.client.clone(); + let handle = client.empty(num_out_elems * core::mem::size_of::()); + let output = + JitTensor::new_contiguous(client.clone(), tensor.device.clone(), shape_output, handle); + + let dummy_array = [1; D]; + if let Some(offset) = offset { + unsafe { + dequantize_per_tensor_affine_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg(vectorization_factor), + // Ignore shape and stride + TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1), + TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1), + output.as_tensor_arg(1), + vectorization_factor > 1, + ) + }; + } else { + unsafe { + dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg(vectorization_factor), + // Ignore shape and stride + TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1), + output.as_tensor_arg(1), + vectorization_factor > 1, + ) + }; + } + + output +} + +/// Convert the tensor back to a higher precision data type. +pub fn dequantize(tensor: QJitTensor) -> JitTensor +where + R: JitRuntime, + F: FloatElement, + I: IntElement, +{ + match tensor.scheme { + QuantizationScheme::PerTensorAffine(dtype) + | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { + QuantizationType::QInt8 => { + dequantize_per_tensor(tensor.qtensor, tensor.qparams.scale, tensor.qparams.offset) + } + }, + } +} diff --git a/crates/burn-jit/src/kernel/quantization/mod.rs b/crates/burn-jit/src/kernel/quantization/mod.rs new file mode 100644 index 000000000..a0244df01 --- /dev/null +++ b/crates/burn-jit/src/kernel/quantization/mod.rs @@ -0,0 +1,5 @@ +mod dequantize; +mod quantize; + +pub use dequantize::*; +pub use quantize::*; diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs new file mode 100644 index 000000000..820177d60 --- /dev/null +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -0,0 +1,219 @@ +use crate::tensor::{JitQuantizationParameters, JitTensor, QJitTensor}; +use crate::FloatElement; +use crate::{IntElement, JitElement, JitRuntime}; +use burn_tensor::quantization::{QuantizationScheme, QuantizationType}; +use cubecl::calculate_cube_count_elemwise; +use cubecl::prelude::*; + +#[cube] +pub(crate) fn quantize_affine_int8( + value: F, + scale: F, + offset: i32, + range_min: F, + range_max: F, +) -> u32 { + let offset = F::cast_from(offset); + + // x_q = clamp(round(x / scale + offset), a, b) + // NOTE: we add 256 before casting to unsigned to correctly represent negative values + u32::cast_from( + i32::cast_from(F::clamp( + F::round((value / scale) + offset), + range_min, + range_max, + )) + 256, + ) +} + +#[cube(launch_unchecked)] +pub(crate) fn quantize_per_tensor_affine_int8_kernel( + input: &Tensor, + scale: &Tensor, + offset: &Tensor, + range_min: f32, + range_max: f32, + output: &mut Tensor, + #[comptime] vectorized: bool, +) { + if ABSOLUTE_POS >= output.len() { + return; + } + + let scale = scale[0]; + let offset = offset[0]; + + let num_packed = 4; + let mut v_packed = 0; + + if vectorized { + // Assuming a vectorization factor of 4 (equal to the number of values packed) + let value = input[ABSOLUTE_POS]; + let vectorization_factor = vectorization_of(input); + #[unroll] + for i in 0..vectorization_factor { + let v = quantize_affine_int8::(value[i], scale, offset, range_min, range_max); + // Shift and combine into u32 + v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); + } + } else { + for i in 0..num_packed { + let v = quantize_affine_int8::( + input[ABSOLUTE_POS + i], + scale, + offset, + range_min, + range_max, + ); + // Shift and combine into u32 + v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); + } + } + + output[ABSOLUTE_POS] = v_packed; +} + +#[cube] +pub(crate) fn quantize_symmetric_int8( + value: F, + scale: F, + range_min: F, + range_max: F, +) -> u32 { + // x_q = clamp(round(x / scale), a, b) + // NOTE: we add 256 before casting to unsigned to correctly represent negative values + u32::cast_from(i32::cast_from(F::clamp(F::round(value / scale), range_min, range_max)) + 256) +} + +// Would have wrapped symmetric with the same affine kernel but cube doesn't support Option for offset. +#[cube(launch_unchecked)] +pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( + input: &Tensor, + scale: &Tensor, + range_min: f32, + range_max: f32, + output: &mut Tensor, + #[comptime] vectorized: bool, +) { + if ABSOLUTE_POS >= output.len() { + return; + } + + let scale = scale[0]; + + let num_packed = 4; + let mut v_packed = 0; + + if vectorized { + // Assuming a vectorization factor of 4 (equal to the number of values packed) + let value = input[ABSOLUTE_POS]; + let vectorization_factor = vectorization_of(input); + #[unroll] + for i in 0..vectorization_factor { + let v = quantize_symmetric_int8::(value[i], scale, range_min, range_max); + // Shift and combine into u32 + v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); + } + } else { + for i in 0..num_packed { + let v = quantize_symmetric_int8::( + input[ABSOLUTE_POS + i], + scale, + range_min, + range_max, + ); + // Shift and combine into u32 + v_packed |= (v & 0xFF) << (8 * (num_packed - i - 1)); + } + } + + output[ABSOLUTE_POS] = v_packed; +} + +pub(crate) fn quantize_per_tensor( + tensor: JitTensor, + scale: JitTensor, + offset: Option>, +) -> JitTensor +where + R: JitRuntime, + F: JitElement, + I: IntElement, +{ + let num_elems = tensor.shape.num_elements(); + let shape_output = tensor.shape.clone(); + let client = tensor.client.clone(); + // Output tensor contains 4x less elements (four int8 values packed in a single u32) + let handle = client.empty(usize::div_ceil(num_elems, 4) * core::mem::size_of::()); + let output = + JitTensor::new_contiguous(client.clone(), tensor.device.clone(), shape_output, handle); + + // Force vectorization to process 4 quantized values packed for 1 output value + let vectorization_factor: u8 = if num_elems < 4 { 1 } else { 4 }; + let cube_dim = CubeDim::default(); + let cube_count = + calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); + + let dummy_array = [1; D]; + if let Some(offset) = offset { + unsafe { + quantize_per_tensor_affine_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg(vectorization_factor), + // Ignore shape and stride + TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1), + TensorArg::from_raw_parts(&offset.handle, &dummy_array, &dummy_array, 1), + ScalarArg::new(i8::MIN as f32), + ScalarArg::new(i8::MAX as f32), + output.as_tensor_arg(1), + vectorization_factor > 1, + ) + }; + } else { + unsafe { + quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg(vectorization_factor), + // Ignore shape and stride + TensorArg::from_raw_parts(&scale.handle, &dummy_array, &dummy_array, 1), + ScalarArg::new(-i8::MAX as f32), + ScalarArg::new(i8::MAX as f32), + output.as_tensor_arg(1), + vectorization_factor > 1, + ) + }; + } + + output +} + +/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. +pub fn quantize( + tensor: JitTensor, + scheme: &QuantizationScheme, + qparams: JitQuantizationParameters, +) -> QJitTensor +where + R: JitRuntime, + F: FloatElement, + I: IntElement, +{ + let qtensor = match scheme { + QuantizationScheme::PerTensorAffine(dtype) + | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { + QuantizationType::QInt8 => { + quantize_per_tensor(tensor, qparams.scale.clone(), qparams.offset.clone()) + } + }, + }; + + QJitTensor { + qtensor, + scheme: scheme.clone(), + qparams, + } +} diff --git a/crates/burn-jit/src/ops/qtensor.rs b/crates/burn-jit/src/ops/qtensor.rs index d6ccb79ad..7451d61ba 100644 --- a/crates/burn-jit/src/ops/qtensor.rs +++ b/crates/burn-jit/src/ops/qtensor.rs @@ -1,12 +1,47 @@ use std::ops::Range; +use alloc::vec::Vec; use burn_tensor::{ ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{QuantizationParametersPrimitive, QuantizationScheme}, - Device, Shape, TensorData, + quantization::{ + QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, + QuantizationStrategy, QuantizationType, + }, + DType, Device, ElementConversion, Shape, TensorData, }; -use crate::{tensor::QJitTensor, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{ + kernel, + tensor::{JitQuantizationParameters, JitTensor, QJitTensor}, + FloatElement, IntElement, JitBackend, JitRuntime, +}; +use cubecl::CubeElement; + +fn pack_i8s_to_u32s(data: &TensorData) -> Vec { + // Shift and combine groups of four 8-bit values into a u32. + // Same as doing this: + // let result = (a_u8 & 0xFF) << 24 | (b_u8 & 0xFF) << 16 | (c_u8 & 0xFF) << 8 | (d_u8 & 0xFF); + data.as_bytes() + .chunks(4) + .map(|x| { + x.iter().enumerate().fold(0u32, |acc, (i, x)| { + acc | (*x as i8 as u32 & 0xFF) << ((3 - i) * 8) + }) + }) + .collect() +} + +/// Create a quantized tensor with packed values (u32). +fn packed_tensor>, const D: usize>( + data: Vec, + shape: S, + device: &R::Device, +) -> JitTensor { + let client = R::client(device); + let buffer = client.create(u32::as_bytes(&data)); + + JitTensor::new_contiguous(client, device.clone(), shape.into(), buffer) +} impl QTensorOps for JitBackend where @@ -15,22 +50,49 @@ where I: IntElement, { fn q_from_data( - _data: TensorData, - _device: &Device, + data: TensorData, + device: &Device, ) -> QuantizedTensor { - todo!() + match data.dtype { + DType::QFloat(strategy) => match strategy { + QuantizationStrategy::PerTensorAffineInt8(q) => { + // Convert quantized values to packed u32s + QJitTensor { + qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device), + scheme: strategy.scheme(), + qparams: JitQuantizationParameters::new( + q.scale.elem(), + Some(q.offset.elem()), + device, + ), + } + } + QuantizationStrategy::PerTensorSymmetricInt8(q) => { + // Convert quantized values to packed u32s + QJitTensor { + qtensor: packed_tensor(pack_i8s_to_u32s(&data), data.shape, device), + scheme: strategy.scheme(), + qparams: JitQuantizationParameters::new(q.scale.elem(), None, device), + } + } + }, + _ => panic!( + "Invalid dtype (expected DType::QFloat, got {:?})", + data.dtype + ), + } } fn quantize( - _tensor: FloatTensor, - _scheme: &QuantizationScheme, - _qparams: QuantizationParametersPrimitive, + tensor: FloatTensor, + scheme: &QuantizationScheme, + qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { - unimplemented!() + kernel::quantization::quantize(tensor, scheme, qparams.into()) } - fn dequantize(_tensor: QuantizedTensor) -> FloatTensor { - unimplemented!() + fn dequantize(tensor: QuantizedTensor) -> FloatTensor { + kernel::quantization::dequantize(tensor) } fn q_shape(tensor: &QuantizedTensor) -> Shape { @@ -42,10 +104,12 @@ where } fn q_to_device( - _tensor: QuantizedTensor, - _device: &Device, + tensor: QuantizedTensor, + device: &Device, ) -> QuantizedTensor { - unimplemented!() + let mut tensor = tensor; + tensor.qtensor = super::to_device(tensor.qtensor, device); + tensor } fn q_reshape( @@ -55,11 +119,43 @@ where QJitTensor { qtensor: super::reshape(tensor.qtensor, shape), scheme: tensor.scheme, + qparams: tensor.qparams, } } - async fn q_into_data(_tensor: QuantizedTensor) -> TensorData { - unimplemented!() + async fn q_into_data(tensor: QuantizedTensor) -> TensorData { + let strategy = tensor.strategy(); + let numel = tensor.qtensor.shape.num_elements(); + let qtensor = kernel::into_contiguous(tensor.qtensor); + + let bytes = qtensor.client.read_async(qtensor.handle.binding()).await; + + // Convert packed bytes to quantized dtype (TensorData can be used with other backends, + // which don't have the prior knowledge of this packed representation) + match &tensor.scheme { + QuantizationScheme::PerTensorAffine(dtype) + | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { + QuantizationType::QInt8 => TensorData::quantized( + u32::from_bytes(&bytes) + .iter() + .enumerate() + .flat_map(|(i, packed)| { + // A single u32 could contain less than four 8-bit values... + let n = core::cmp::min(4, numel - i * 4); + // Extract each 8-bit segment from u32 and cast back to i8 + // Same as doing this (when 4 values are fully packed): + // let a = ((packed >> 24) & 0xFF) as i8; + // let b = ((packed >> 16) & 0xFF) as i8; + // let c = ((packed >> 8) & 0xFF) as i8; + // let d = (packed & 0xFF) as i8; + (0..n).map(move |i| (packed >> ((3 - i) * 8) & 0xFF) as i8) + }) + .collect(), + qtensor.shape, + strategy, + ), + }, + } } fn q_swap_dims( diff --git a/crates/burn-jit/src/tensor/qtensor.rs b/crates/burn-jit/src/tensor/qtensor.rs index 08523f6ce..62736b1a7 100644 --- a/crates/burn-jit/src/tensor/qtensor.rs +++ b/crates/burn-jit/src/tensor/qtensor.rs @@ -1,34 +1,111 @@ -use burn_tensor::quantization::{QTensorPrimitive, QuantizationScheme, QuantizationStrategy}; +use burn_tensor::{ + quantization::{ + AffineQuantization, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, + QuantizationStrategy, QuantizationType, SymmetricQuantization, + }, + read_sync, TensorData, +}; -use crate::JitRuntime; +use crate::{ops::into_data, FloatElement, IntElement, JitBackend, JitRuntime}; use super::JitTensor; /// A quantized tensor primitive. #[derive(Debug)] -pub struct QJitTensor { +pub struct QJitTensor { /// The quantized tensor. - // TODO: implement `JitElement` / `CubeElement` for quantized type + /// Values are stored as multiple packed quantized values in u32. pub qtensor: JitTensor, /// The quantization scheme. pub scheme: QuantizationScheme, + /// The quantization parameters. + pub qparams: JitQuantizationParameters, } -impl QTensorPrimitive for QJitTensor { +impl QTensorPrimitive + for QJitTensor +{ fn scheme(&self) -> &QuantizationScheme { &self.scheme } fn strategy(&self) -> QuantizationStrategy { - todo!() + match &self.scheme { + QuantizationScheme::PerTensorAffine(dtype) => match dtype { + QuantizationType::QInt8 => { + let scale = read_sync(into_data(self.qparams.scale.clone())) + .iter() + .next() + .unwrap(); + let offset = read_sync(into_data(self.qparams.offset.clone().unwrap())) + .iter() + .next() + .unwrap(); + QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init( + scale, offset, + )) + } + }, + QuantizationScheme::PerTensorSymmetric(dtype) => match dtype { + QuantizationType::QInt8 => { + let scale = read_sync(into_data(self.qparams.scale.clone())) + .iter() + .next() + .unwrap(); + QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(scale)) + } + }, + } } } -impl Clone for QJitTensor { +impl Clone + for QJitTensor +{ fn clone(&self) -> Self { Self { qtensor: self.qtensor.clone(), scheme: self.scheme.clone(), + qparams: self.qparams.clone(), + } + } +} + +/// The quantization parameters. +#[derive(Debug)] +pub struct JitQuantizationParameters { + /// The scaling factor. + pub scale: JitTensor, + /// The zero-point offset. + pub offset: Option>, +} + +impl Clone for JitQuantizationParameters { + fn clone(&self) -> Self { + Self { + scale: self.scale.clone(), + offset: self.offset.clone(), + } + } +} + +impl + From>> + for JitQuantizationParameters +{ + fn from(value: QuantizationParametersPrimitive>) -> Self { + JitQuantizationParameters { + scale: value.scale, + offset: value.offset, + } + } +} + +impl JitQuantizationParameters { + pub fn new(scale: F, offset: Option, device: &R::Device) -> Self { + Self { + scale: crate::ops::from_data(TensorData::new(vec![scale], [1]), device), + offset: offset.map(|o| crate::ops::from_data(TensorData::new(vec![o], [1]), device)), } } } diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index d4be968fc..c805d9084 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -16,6 +16,7 @@ mod matmul; mod max_pool2d; mod max_pool2d_backward; mod normal; +mod quantization; mod reduce; mod repeat_dim; mod scatter; @@ -73,6 +74,8 @@ macro_rules! testgen_all { burn_jit::testgen_cat!(); burn_jit::testgen_clamp!(); burn_jit::testgen_unary!(); + + burn_jit::testgen_quantization!(); } } mod jit_fusion { @@ -100,6 +103,14 @@ macro_rules! testgen_jit { burn_tensor::testgen_all!(); burn_autodiff::testgen_all!(); + + // Not all ops are implemented for quantization yet, notably missing: + // `q_swap_dims`, `q_permute`, `q_flip`, `q_gather`, `q_select`, `q_slice`, `q_expand` + // burn_tensor::testgen_quantization!(); + // test quantization + burn_tensor::testgen_calibration!(); + burn_tensor::testgen_scheme!(); + burn_tensor::testgen_quantize!(); }; } diff --git a/crates/burn-jit/src/tests/quantization.rs b/crates/burn-jit/src/tests/quantization.rs new file mode 100644 index 000000000..77cf8dbb9 --- /dev/null +++ b/crates/burn-jit/src/tests/quantization.rs @@ -0,0 +1,82 @@ +#[burn_tensor_testgen::testgen(quantization)] +mod tests { + use super::*; + use burn_tensor::{ + quantization::{QuantizationScheme, QuantizationType}, + Tensor, + }; + + #[test] + fn should_quantize_dequantize_symmetric_single() { + let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8); + let input = Tensor::::from_floats([-1.8], &Default::default()); + let input_ref = + Tensor::::from_data(input.to_data(), &Default::default()); + + let output = input.quantize_dynamic(&scheme); + let output_ref = input_ref.quantize_dynamic(&scheme); + + output.to_data().assert_eq(&output_ref.to_data(), false); + + let output = output.dequantize(); + let output_ref = output_ref.dequantize(); + + output.to_data().assert_approx_eq(&output_ref.to_data(), 3); + } + + #[test] + fn should_quantize_dequantize_affine_single() { + let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); + let input = Tensor::::from_floats([-1.8], &Default::default()); + let input_ref = + Tensor::::from_data(input.to_data(), &Default::default()); + + let output = input.quantize_dynamic(&scheme); + let output_ref = input_ref.quantize_dynamic(&scheme); + + output.to_data().assert_eq(&output_ref.to_data(), false); + + let output = output.dequantize(); + let output_ref = output_ref.dequantize(); + + output.to_data().assert_approx_eq(&output_ref.to_data(), 2); + } + + #[test] + fn should_quantize_dequantize_symmetric_multiple() { + let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8); + let input = + Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5, 0.0], &Default::default()); + let input_ref = + Tensor::::from_data(input.to_data(), &Default::default()); + + let output = input.quantize_dynamic(&scheme); + let output_ref = input_ref.quantize_dynamic(&scheme); + + output.to_data().assert_eq(&output_ref.to_data(), false); + + let output = output.dequantize(); + let output_ref = output_ref.dequantize(); + + output.to_data().assert_approx_eq(&output_ref.to_data(), 3); + } + + #[test] + fn should_quantize_dequantize_affine_multiple() { + let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8); + let input = + Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5, 0.0], &Default::default()); + let input_ref = + Tensor::::from_data(input.to_data(), &Default::default()); + + let output = input.quantize_dynamic(&scheme); + let output_ref = input_ref.quantize_dynamic(&scheme); + + output.to_data().assert_eq(&output_ref.to_data(), false); + + let output = output.dequantize(); + let output_ref = output_ref.dequantize(); + + output.to_data().assert_approx_eq(&output_ref.to_data(), 3); + } +} diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 0c952c3ac..8f9420345 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -361,22 +361,21 @@ impl TensorData { DType::Bool => self.assert_eq_elem::(other), DType::QFloat(q) => { // Strict or not, it doesn't make sense to compare quantized data to not quantized data for equality - if let DType::QFloat(q_other) = other.dtype { - assert_eq!( - q, q_other, - "Quantization strategies differ ({:?} != {:?})", - q, q_other - ) + let q_other = if let DType::QFloat(q_other) = other.dtype { + q_other } else { panic!("Quantized data differs from other not quantized data") - } - match q { - QuantizationStrategy::PerTensorAffineInt8(_) => { - self.assert_eq_elem::(other) - } - QuantizationStrategy::PerTensorSymmetricInt8(_) => { - self.assert_eq_elem::(other) - } + }; + match (q, q_other) { + ( + QuantizationStrategy::PerTensorAffineInt8(_), + QuantizationStrategy::PerTensorAffineInt8(_), + ) => self.assert_eq_elem::(other), + ( + QuantizationStrategy::PerTensorSymmetricInt8(_), + QuantizationStrategy::PerTensorSymmetricInt8(_), + ) => self.assert_eq_elem::(other), + _ => panic!("Quantization strategies differ ({:?} != {:?})", q, q_other), } } } diff --git a/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs b/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs index 85a1db1b0..57ad69edb 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/aggregation.rs @@ -194,7 +194,7 @@ mod tests { output .dequantize() .into_data() - .assert_eq(&TensorData::from([0.0]), false); + .assert_approx_eq(&TensorData::from([0.0]), 5); } #[test] diff --git a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs index 699100141..438b973f0 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs @@ -55,15 +55,6 @@ mod tests { #[test] fn should_support_dequantize() { let device = Default::default(); - let tensor = Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5], &device); - let scheme = QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8); - let qparams = QuantizationParameters { - scale: Tensor::from_floats([0.014_173_228], &device), - offset: None, - }; - - let x_q = tensor.quantize(&scheme, qparams); - // Quantized [-1.8, -1.0, 0.0, 0.5] let data = TensorData::quantized( vec![-127i8, -71, 0, 35], @@ -97,6 +88,6 @@ mod tests { QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.05882353, 42)), ); - x_q.to_data().assert_eq(&expected, true); + x_q.to_data().assert_eq(&expected, false); } }